From 8cbcc33876c773722163b2259644037bbb259bd1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 9 Feb 2018 12:54:57 +0800 Subject: [PATCH 001/593] [SPARK-23186][SQL] Initialize DriverManager first before loading JDBC Drivers ## What changes were proposed in this pull request? Since some JDBC Drivers have class initialization code to call `DriverManager`, we need to initialize `DriverManager` first in order to avoid potential executor-side **deadlock** situations like the following (or [STORM-2527](https://issues.apache.org/jira/browse/STORM-2527)). ``` Thread 9587: (state = BLOCKED) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame; information may be imprecise) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - java.util.ServiceLoader$LazyIterator.nextService() bci=119, line=380 (Interpreted frame) - java.util.ServiceLoader$LazyIterator.next() bci=11, line=404 (Interpreted frame) - java.util.ServiceLoader$1.next() bci=37, line=480 (Interpreted frame) - java.sql.DriverManager$2.run() bci=21, line=603 (Interpreted frame) - java.sql.DriverManager$2.run() bci=1, line=583 (Interpreted frame) - java.security.AccessController.doPrivileged(java.security.PrivilegedAction) bci=0 (Compiled frame) - java.sql.DriverManager.loadInitialDrivers() bci=27, line=583 (Interpreted frame) - java.sql.DriverManager.() bci=32, line=101 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getConnection(java.lang.String, java.lang.Integer, java.lang.String, java.util.Properties) bci=12, line=98 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getInputConnection(org.apache.hadoop.conf.Configuration, java.util.Properties) bci=22, line=57 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.getQueryPlan(org.apache.hadoop.mapreduce.JobContext, org.apache.hadoop.conf.Configuration) bci=61, line=116 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.createRecordReader(org.apache.hadoop.mapreduce.InputSplit, org.apache.hadoop.mapreduce.TaskAttemptContext) bci=10, line=71 (Interpreted frame) - org.apache.spark.rdd.NewHadoopRDD$$anon$1.(org.apache.spark.rdd.NewHadoopRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=233, line=156 (Interpreted frame) Thread 9170: (state = BLOCKED) - org.apache.phoenix.jdbc.PhoenixDriver.() bci=35, line=125 (Interpreted frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(java.lang.String) bci=89, line=46 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=7, line=53 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=1, line=52 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$$anon$1.(org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=81, line=347 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD.compute(org.apache.spark.Partition, org.apache.spark.TaskContext) bci=7, line=339 (Interpreted frame) ``` ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #20359 from dongjoon-hyun/SPARK-23186. --- .../sql/execution/datasources/jdbc/DriverRegistry.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7a6c0f9fed2f9..1723596de1db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -32,6 +32,13 @@ import org.apache.spark.util.Utils */ object DriverRegistry extends Logging { + /** + * Load DriverManager first to avoid any race condition between + * DriverManager static initialization block and specific driver class's + * static initialization block. e.g. PhoenixDriver + */ + DriverManager.getDrivers + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty def register(className: String): Unit = { From 4b4ee2601079f12f8f410a38d2081793cbdedc14 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 9 Feb 2018 14:21:10 +0800 Subject: [PATCH 002/593] [SPARK-23328][PYTHON] Disallow default value None in na.replace/replace when 'to_replace' is not a dictionary ## What changes were proposed in this pull request? This PR proposes to disallow default value None when 'to_replace' is not a dictionary. It seems weird we set the default value of `value` to `None` and we ended up allowing the case as below: ```python >>> df.show() ``` ``` +----+------+-----+ | age|height| name| +----+------+-----+ | 10| 80|Alice| ... ``` ```python >>> df.na.replace('Alice').show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` **After** This PR targets to disallow the case above: ```python >>> df.na.replace('Alice').show() ``` ``` ... TypeError: value is required when to_replace is not a dictionary. ``` while we still allow when `to_replace` is a dictionary: ```python >>> df.na.replace({'Alice': None}).show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` ## How was this patch tested? Manually tested, tests were added in `python/pyspark/sql/tests.py` and doctests were fixed. Author: hyukjinkwon Closes #20499 from HyukjinKwon/SPARK-19454-followup. --- docs/sql-programming-guide.md | 1 + python/pyspark/__init__.py | 1 + python/pyspark/_globals.py | 70 +++++++++++++++++++++++++++++++++ python/pyspark/sql/dataframe.py | 26 +++++++++--- python/pyspark/sql/tests.py | 11 +++--- 5 files changed, 99 insertions(+), 10 deletions(-) create mode 100644 python/pyspark/_globals.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0e221b39cc34..eab4030ee25d2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1929,6 +1929,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 4d142c91629cc..58218918693ca 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -54,6 +54,7 @@ from pyspark.taskcontext import TaskContext from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ +from pyspark._globals import _NoValue def since(version): diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py new file mode 100644 index 0000000000000..8e6099db09963 --- /dev/null +++ b/python/pyspark/_globals.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if pyspark itself is reloaded. In particular, a function like the following +will still work correctly after pyspark is reloaded: + + def foo(arg=pyspark._NoValue): + if arg is pyspark._NoValue: + ... + +See gh-7844 for a discussion of the reload problem that motivated this module. + +Note that this approach is taken after from NumPy. +""" + +__ALL__ = ['_NoValue'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading pyspark._globals is not allowed') +_is_loaded = True + + +class _NoValueType(object): + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + deprecated keyword in order to check if it has been given a user defined + value. + + This class was copied from NumPy. + """ + __instance = None + + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super(_NoValueType, cls).__new__(cls) + return cls.__instance + + # needed for python 2 to preserve identity through a pickle + def __reduce__(self): + return (self.__class__, ()) + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8ec24db8717b2..faee870a2d2e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,7 @@ import warnings -from pyspark import copy_func, since +from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer @@ -1532,7 +1532,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1545,8 +1545,8 @@ def replace(self, to_replace, value=None, subset=None): :param to_replace: bool, int, long, float, string, list or dict. Value to be replaced. - If the value is a dict, then `value` is ignored and `to_replace` must be a - mapping between a value and a replacement. + If the value is a dict, then `value` is ignored or can be omitted, and `to_replace` + must be a mapping between a value and a replacement. :param value: bool, int, long, float, string, list or None. The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. @@ -1577,6 +1577,16 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ + >>> df4.na.replace({'Alice': None}).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1587,6 +1597,12 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ """ + if value is _NoValue: + if isinstance(to_replace, dict): + value = None + else: + raise TypeError("value argument is required when to_replace is not a dictionary.") + # Helper functions def all_of(types): """Given a type or tuple of types and a sequence of xs @@ -2047,7 +2063,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 90ff084fed55e..6ace16955000d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2243,11 +2243,6 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) - # replace list while value is not given (default to None) - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - self.assertTupleEqual(row, (None, 10, 80.0)) - # replace string with None and then drop None rows row = self.spark.createDataFrame( [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() @@ -2283,6 +2278,12 @@ def test_replace(self): self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + with self.assertRaisesRegexp( + TypeError, + 'value argument is required when to_replace is not a dictionary.'): + self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From f77270b8811bbd8956d0c08fa556265d2c5ee20e Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 9 Feb 2018 08:45:06 -0600 Subject: [PATCH 003/593] [SPARK-23358][CORE] When the number of partitions is greater than 2^28, it will result in an error result ## What changes were proposed in this pull request? In the `checkIndexAndDataFile`,the `blocks` is the ` Int` type, when it is greater than 2^28, `blocks*8` will overflow, and this will result in an error result. In fact, `blocks` is actually the number of partitions. ## How was this patch tested? Manual test Author: liuxian Closes #20544 from 10110346/overflow. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index c5f3f6e2b42b6..d88b25cc7e258 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -84,7 +84,7 @@ private[spark] class IndexShuffleBlockResolver( */ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8) { + if (index.length() != (blocks + 1) * 8L) { return null } val lengths = new Array[Long](blocks) From 0fc26313f8071cdcb4ccd67bb1d6942983199d36 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 9 Feb 2018 08:46:27 -0600 Subject: [PATCH 004/593] [SPARK-21860][CORE][FOLLOWUP] fix java style error ## What changes were proposed in this pull request? #19077 introduced a Java style error (too long line). Quick fix. ## How was this patch tested? running `./dev/lint-java` Author: Marco Gaido Closes #20558 from mgaido91/SPARK-21860. --- .../test/java/org/apache/spark/unsafe/PlatformUtilSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 71c53d35dcab8..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -139,7 +139,8 @@ public void memoryDebugFillEnabledInTest() { @Test public void heapMemoryReuse() { MemoryAllocator heapMem = new HeapMemoryAllocator(); - // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // allocate new memory every time. MemoryBlock onheap1 = heapMem.allocate(513); Object obj1 = onheap1.getBaseObject(); heapMem.free(onheap1); From 7f10cf83f311526737fc96d5bb8281d12e41932f Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Fri, 9 Feb 2018 11:21:20 -0800 Subject: [PATCH 005/593] [SPARK-16501][MESOS] Allow providing Mesos principal & secret via files This commit modifies the Mesos submission client to allow the principal and secret to be provided indirectly via files. The path to these files can be specified either via Spark configuration or via environment variable. Assuming these files are appropriately protected by FS/OS permissions this means we don't ever leak the actual values in process info like ps Environment variable specification is useful because it allows you to interpolate the location of this file when using per-user Mesos credentials. For some background as to why we have taken this approach I will briefly describe our set up. On our systems we provide each authorised user account with their own Mesos credentials to provide certain security and audit guarantees to our customers. These credentials are managed by a central Secret management service. In our `spark-env.sh` we determine the appropriate secret and principal files to use depending on the user who is invoking Spark hence the need to inject these via environment variables as well as by configuration properties. So we set these environment variables appropriately and our Spark read in the contents of those files to authenticate itself with Mesos. This is functionality we have been using it in production across multiple customer sites for some time. This has been in the field for around 18 months with no reported issues. These changes have been sufficient to meet our customer security and audit requirements. We have been building and deploying custom builds of Apache Spark with various minor tweaks like this which we are now looking to contribute back into the community in order that we can rely upon stock Apache Spark builds and stop maintaining our own internal fork. Author: Rob Vesse Closes #20167 from rvesse/SPARK-16501. --- docs/running-on-mesos.md | 40 ++++- .../cluster/mesos/MesosSchedulerUtils.scala | 55 ++++-- .../mesos/MesosSchedulerUtilsSuite.scala | 161 +++++++++++++++++- 3 files changed, 238 insertions(+), 18 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 2bb5ecf1b8509..8e58892e2689f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -82,6 +82,27 @@ a Spark driver program configured to connect to Mesos. Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure `spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. +## Authenticating to Mesos + +When Mesos Framework authentication is enabled it is necessary to provide a principal and secret by which to authenticate Spark to Mesos. Each Spark job will register with Mesos as a separate framework. + +Depending on your deployment environment you may wish to create a single set of framework credentials that are shared across all users or create framework credentials for each user. Creating and managing framework credentials should be done following the Mesos [Authentication documentation](http://mesos.apache.org/documentation/latest/authentication/). + +Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. + +Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. + +### Credential Specification Preference Order + +Please note that if you specify multiple ways to obtain the credentials then the following preference order applies. Spark will use the first valid value found and any subsequent values are ignored: + +- `spark.mesos.principal` configuration setting +- `SPARK_MESOS_PRINCIPAL` environment variable +- `spark.mesos.principal.file` configuration setting +- `SPARK_MESOS_PRINCIPAL_FILE` environment variable + +An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -427,7 +448,14 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.principal (none) - Set the principal with which Spark framework will use to authenticate with Mesos. + Set the principal with which Spark framework will use to authenticate with Mesos. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL`. + + + + spark.mesos.principal.file + (none) + + Set the file containing the principal with which Spark framework will use to authenticate with Mesos. Allows specifying the principal indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL_FILE`. @@ -435,7 +463,15 @@ See the [configuration page](configuration.html) for information on Spark config (none) Set the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when - authenticating with the registry. + authenticating with the registry. You can also specify this via the environment variable `SPARK_MESOS_SECRET`. + + + + spark.mesos.secret.file + (none) + + Set the file containing the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when + authenticating with the registry. Allows for specifying the secret indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_SECRET_FILE`. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index e75450369ad85..ecbcc960fc5a0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.File +import java.nio.charset.StandardCharsets import java.util.{List => JList} import java.util.concurrent.CountDownLatch @@ -25,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter +import com.google.common.io.Files import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability @@ -71,26 +74,15 @@ trait MesosSchedulerUtils extends Logging { failoverTimeout: Option[Double] = None, frameworkId: Option[String] = None): SchedulerDriver = { val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) - val credBuilder = Credential.newBuilder() + fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS))) webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } frameworkId.foreach { id => fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) } - fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( - conf.get(DRIVER_HOST_ADDRESS))) - conf.getOption("spark.mesos.principal").foreach { principal => - fwInfoBuilder.setPrincipal(principal) - credBuilder.setPrincipal(principal) - } - conf.getOption("spark.mesos.secret").foreach { secret => - credBuilder.setSecret(secret) - } - if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { - throw new SparkException( - "spark.mesos.principal must be configured when spark.mesos.secret is set") - } + conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } @@ -98,6 +90,7 @@ trait MesosSchedulerUtils extends Logging { if (maxGpus > 0) { fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) } + val credBuilder = buildCredentials(conf, fwInfoBuilder) if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) @@ -106,6 +99,40 @@ trait MesosSchedulerUtils extends Logging { } } + def buildCredentials( + conf: SparkConf, + fwInfoBuilder: Protos.FrameworkInfo.Builder): Protos.Credential.Builder = { + val credBuilder = Credential.newBuilder() + conf.getOption("spark.mesos.principal") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL"))) + .orElse( + conf.getOption("spark.mesos.principal.file") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL_FILE"))) + .map { principalFile => + Files.toString(new File(principalFile), StandardCharsets.UTF_8) + } + ).foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET"))) + .orElse( + conf.getOption("spark.mesos.secret.file") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET_FILE"))) + .map { secretFile => + Files.toString(new File(secretFile), StandardCharsets.UTF_8) + } + ).foreach { secret => + credBuilder.setSecret(secret) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + credBuilder + } + /** * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. * This driver is expected to not be running. diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 7df738958f85c..8d90e1a8591ad 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.{File, FileNotFoundException} + import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import org.apache.mesos.Protos.{Resource, Value} +import com.google.common.io.Files +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Value} import org.mockito.Mockito._ import org.scalatest._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.internal.config._ +import org.apache.spark.util.SparkConfWithEnv class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { @@ -237,4 +241,157 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} portsToUse.isEmpty shouldBe true } + + test("Principal specified via spark.mesos.principal") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", pFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "test-principal")) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> pFile.getAbsolutePath())) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE that does not exist") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> "/tmp/does-not-exist")) + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Secret specified via spark.mesos.secret") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", sFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_SECRET") { + val env = Map("SPARK_MESOS_SECRET" -> "my-secret") + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via SPARK_MESOS_SECRET_FILE") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + + val sFilePath = sFile.getAbsolutePath() + val env = Map("SPARK_MESOS_SECRET_FILE" -> sFilePath) + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Secret specified with no principal") { + val conf = new SparkConf() + conf.set("spark.mesos.secret", "my-secret") + + intercept[SparkException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "other-principal")) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Secret specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_SECRET" -> "other-secret")) + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } } From 557938e2839afce26a10a849a2a4be8fc4580427 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 9 Feb 2018 18:18:30 -0600 Subject: [PATCH 006/593] [MINOR][HIVE] Typo fixes ## What changes were proposed in this pull request? Typo fixes (with expanding a Hive property) ## How was this patch tested? local build. Awaiting Jenkins Author: Jacek Laskowski Closes #20550 from jaceklaskowski/hiveutils-typos. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 93f3f38e52aa9..c448c5a9821be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -304,7 +304,7 @@ private[spark] object HiveUtils extends Logging { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"Specify a valid path to the correct hive jars using ${HIVE_METASTORE_JARS.key} " + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } @@ -324,7 +324,7 @@ private[spark] object HiveUtils extends Logging { if (jars.length == 0) { throw new IllegalArgumentException( "Unable to locate hive jars to connect to metastore. " + - "Please set spark.sql.hive.metastore.jars.") + s"Please set ${HIVE_METASTORE_JARS.key}.") } logInfo( From 6d7c38330e68c7beb10f54eee8b4f607ee3c4136 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 9 Feb 2018 16:21:47 -0800 Subject: [PATCH 007/593] [SPARK-23275][SQL] fix the thread leaking in hive/tests ## What changes were proposed in this pull request? This is a follow up of https://github.com/apache/spark/pull/20441. The two lines actually can trigger the hive metastore bug: https://issues.apache.org/jira/browse/HIVE-16844 The two configs are not in the default `ObjectStore` properties, so any run hive commands after these two lines will set the `propsChanged` flag in the `ObjectStore.setConf` and then cause thread leaks. I don't think the two lines are very useful. They can be removed safely. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20562 from liufengdb/fix-omm. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 59708e7a0f2ff..19028939f3673 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -530,8 +530,6 @@ private[hive] class TestHiveSparkSession( // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 metadataHive.runSqlHive("set hive.table.parameters.default=") - metadataHive.runSqlHive("set datanucleus.cache.collections=true") - metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") From 97a224a855c4410b2dfb9c0bcc6aae583bd28e92 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 11 Feb 2018 01:08:02 +0900 Subject: [PATCH 008/593] [SPARK-23360][SQL][PYTHON] Get local timezone from environment via pytz, or dateutil. ## What changes were proposed in this pull request? Currently we use `tzlocal()` to get Python local timezone, but it sometimes causes unexpected behavior. I changed the way to get Python local timezone to use pytz if the timezone is specified in environment variable, or timezone file via dateutil . ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20559 from ueshin/issues/SPARK-23360/master. --- python/pyspark/sql/tests.py | 28 ++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 23 +++++++++++++++++++---- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6ace16955000d..1087c3fafdd16 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2868,6 +2868,34 @@ def test_create_dataframe_required_pandas_not_found(self): "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) + # Regression test for SPARK-23360 + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) + def test_create_dateframe_from_pandas_with_dst(self): + import pandas as pd + from datetime import datetime + + pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + + orig_env_tz = os.environ.get('TZ', None) + orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') + try: + tz = 'America/Los_Angeles' + os.environ['TZ'] = tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', tz) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + finally: + del os.environ['TZ'] + if orig_env_tz is not None: + os.environ['TZ'] = orig_env_tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 093dae5a22e1f..2599dc5fdc599 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1709,6 +1709,21 @@ def _check_dataframe_convert_date(pdf, schema): return pdf +def _get_local_timezone(): + """ Get local timezone using pytz with environment variable, or dateutil. + + If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone + string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and + it reads system configuration to know the system local timezone. + + See also: + - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 + - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 + """ + import os + return os.environ.get('TZ', 'dateutil/:') + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1721,7 +1736,7 @@ def _check_dataframe_localize_timestamps(pdf, timezone): require_minimum_pandas_version() from pandas.api.types import is_datetime64tz_dtype - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(series.dtype): @@ -1744,7 +1759,7 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() return s.dt.tz_localize(tz).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') @@ -1766,8 +1781,8 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): import pandas as pd from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - from_tz = from_timezone or 'tzlocal()' - to_tz = to_timezone or 'tzlocal()' + from_tz = from_timezone or _get_local_timezone() + to_tz = to_timezone or _get_local_timezone() # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert(to_tz).dt.tz_localize(None) From 0783876c81f212e1422a1b7786c26e3ac8e84f9f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Feb 2018 10:46:45 -0600 Subject: [PATCH 009/593] [SPARK-23344][PYTHON][ML] Add distanceMeasure param to KMeans ## What changes were proposed in this pull request? SPARK-22119 introduced a new parameter for KMeans, ie. `distanceMeasure`. The PR adds it also to the Python interface. ## How was this patch tested? added UTs Author: Marco Gaido Closes #20520 from mgaido91/SPARK-23344. --- python/pyspark/ml/clustering.py | 32 +++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 18 ++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 66fb00508522e..6448b76a0da88 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -403,17 +403,23 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) + self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, + distanceMeasure="euclidean") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -423,10 +429,12 @@ def _create_model(self, java_model): @keyword_only @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") Sets params for KMeans. """ @@ -475,6 +483,20 @@ def getInitSteps(self): """ return self.getOrDefault(self.initSteps) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 75d04785a0710..6d6737241e06e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -418,6 +418,9 @@ def test_kmeans_param(self): self.assertEqual(algo.getK(), 10) algo.setInitSteps(10) self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") def test_hasseed(self): noSeedSpecd = TestParams() @@ -1620,6 +1623,21 @@ def test_kmeans_summary(self): self.assertEqual(s.k, 2) +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + class OneVsRestTests(SparkSessionTestCase): def test_copy(self): From a34fce19bc0ee5a7e36c6ecba75d2aeb70fdcbc7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sun, 11 Feb 2018 17:31:35 +0900 Subject: [PATCH 010/593] [SPARK-23314][PYTHON] Add ambiguous=False when localizing tz-naive timestamps in Arrow codepath to deal with dst ## What changes were proposed in this pull request? When tz_localize a tz-naive timetamp, pandas will throw exception if the timestamp is during daylight saving time period, e.g., `2015-11-01 01:30:00`. This PR fixes this issue by setting `ambiguous=False` when calling tz_localize, which is the same default behavior of pytz. ## How was this patch tested? Add `test_timestamp_dst` Author: Li Jin Closes #20537 from icexelloss/SPARK-23314. --- python/pyspark/sql/tests.py | 39 +++++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 37 ++++++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1087c3fafdd16..4bc59fd99fca5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3670,6 +3670,21 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) + + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4311,6 +4326,18 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res1.collect()) self.assertEquals(expected.collect(), res2.collect()) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda x: x, 'timestamp') + result = df.withColumn('time', foo_udf(df.time)) + self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4482,6 +4509,18 @@ def test_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) + result = df.groupby('time').apply(foo_udf).sort('time') + self.assertPandasEqual(df.toPandas(), result.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 2599dc5fdc599..f7141b4549e4e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1759,8 +1759,38 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): + # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive + # timestamp is during the hour when the clock is adjusted backward during due to + # daylight saving time (dst). + # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to + # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize + # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either + # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). + # + # Here we explicit choose to use standard time. This matches the default behavior of + # pytz. + # + # Here are some code to help understand this behavior: + # >>> import datetime + # >>> import pandas as pd + # >>> import pytz + # >>> + # >>> t = datetime.datetime(2015, 11, 1, 1, 30) + # >>> ts = pd.Series([t]) + # >>> tz = pytz.timezone('America/New_York') + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=True) + # 0 2015-11-01 01:30:00-04:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=False) + # 0 2015-11-01 01:30:00-05:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> str(tz.localize(t)) + # '2015-11-01 01:30:00-05:00' tz = timezone or _get_local_timezone() - return s.dt.tz_localize(tz).dt.tz_convert('UTC') + return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: @@ -1788,8 +1818,9 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): return s.dt.tz_convert(to_tz).dt.tz_localize(None) elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. - return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) - if ts is not pd.NaT else pd.NaT) + return s.apply( + lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) else: return s From 8acb51f08b448628b65e90af3b268994f9550e45 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Feb 2018 18:55:38 +0900 Subject: [PATCH 011/593] [SPARK-23084][PYTHON] Add unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark ## What changes were proposed in this pull request? Added unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark, also updated the rangeBetween API ## How was this patch tested? did unit test on my local. Please let me know if I need to add unit test in tests.py Author: Huaxin Gao Closes #20400 from huaxingao/spark_23084. --- python/pyspark/sql/functions.py | 30 ++++++++++++++ python/pyspark/sql/window.py | 70 ++++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 05031f5ec87d7..9bb9c323a5a60 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -809,6 +809,36 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +@since(2.4) +def unboundedPreceding(): + """ + Window function: returns the special frame boundary that represents the first row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedPreceding()) + + +@since(2.4) +def unboundedFollowing(): + """ + Window function: returns the special frame boundary that represents the last row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedFollowing()) + + +@since(2.4) +def currentRow(): + """ + Window function: returns the special frame boundary that represents the current row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.currentRow()) + + # ---------------------- Date/Timestamp functions ------------------------------ @since(1.5) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 7ce27f9b102c0..bb841a9b9ff7c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -16,9 +16,11 @@ # import sys +if sys.version >= '3': + long = int from pyspark import since, SparkContext -from pyspark.sql.column import _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -120,20 +122,45 @@ def rangeBetween(start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). + + >>> from pyspark.sql import functions as F, SparkSession, Window + >>> spark = SparkSession.builder.getOrCreate() + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> window = Window.orderBy("id").partitionBy("category").rangeBetween( + ... F.currentRow(), F.lit(1)) + >>> df.withColumn("sum", F.sum("id").over(window)).show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| b| 3| + | 2| b| 5| + | 3| b| 3| + | 1| a| 4| + | 1| a| 4| + | 2| a| 2| + +---+--------+---+ """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -208,27 +235,34 @@ def rangeBetween(self, start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc return WindowSpec(self._jspec.rangeBetween(start, end)) def _test(): import doctest SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod() + (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: exit(-1) From eacb62fbbed317fd0e972102838af231385d54d8 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 11 Feb 2018 19:23:15 +0900 Subject: [PATCH 012/593] [SPARK-22624][PYSPARK] Expose range partitioning shuffle introduced by spark-22614 ## What changes were proposed in this pull request? Expose range partitioning shuffle introduced by spark-22614 ## How was this patch tested? Unit test in dataframe.py Please review http://spark.apache.org/contributing.html before opening a pull request. Author: xubo245 <601450868@qq.com> Closes #20456 from xubo245/SPARK22624_PysparkRangePartition. --- python/pyspark/sql/dataframe.py | 45 +++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 28 ++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index faee870a2d2e2..5cc8b63cdfadf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -667,6 +667,51 @@ def repartition(self, numPartitions, *cols): else: raise TypeError("numPartitions should be an int or Column") + @since("2.4.0") + def repartitionByRange(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is range partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + At least one partition-by expression must be specified. + When no explicit sort order is specified, "ascending nulls first" is assumed. + + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() + 2 + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + >>> df.repartitionByRange(1, "age").rdd.getNumPartitions() + 1 + >>> data = df.repartitionByRange("age") + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return ValueError("At least one partition-by expression must be specified.") + else: + return DataFrame( + self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions,) + cols + return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int, string or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4bc59fd99fca5..fe89bd0685027 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2148,6 +2148,34 @@ def test_expr(self): result = df.select(functions.expr("length(a)")).collect()[0].asDict() self.assertEqual(13, result["length(a)"]) + def test_repartitionByRange_dataframe(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + df1 = self.spark.createDataFrame( + [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) + df2 = self.spark.createDataFrame( + [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) + + # test repartitionByRange(numPartitions, *cols) + df3 = df1.repartitionByRange(2, "name", "age") + self.assertEqual(df3.rdd.getNumPartitions(), 2) + self.assertEqual(df3.rdd.first(), df2.rdd.first()) + self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(numPartitions, *cols) + df4 = df1.repartitionByRange(3, "name", "age") + self.assertEqual(df4.rdd.getNumPartitions(), 3) + self.assertEqual(df4.rdd.first(), df2.rdd.first()) + self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(*cols) + df5 = df1.repartitionByRange("name", "age") + self.assertEqual(df5.rdd.first(), df2.rdd.first()) + self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) + def test_replace(self): schema = StructType([ StructField("name", StringType(), True), From 4bbd7443ebb005f81ed6bc39849940ac8db3b3cc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 00:03:49 +0800 Subject: [PATCH 013/593] [SPARK-23376][SQL] creating UnsafeKVExternalSorter with BytesToBytesMap may fail ## What changes were proposed in this pull request? This is a long-standing bug in `UnsafeKVExternalSorter` and was reported in the dev list multiple times. When creating `UnsafeKVExternalSorter` with `BytesToBytesMap`, we need to create a `UnsafeInMemorySorter` to sort the data in `BytesToBytesMap`. The data format of the sorter and the map is same, so no data movement is required. However, both the sorter and the map need a point array for some bookkeeping work. There is an optimization in `UnsafeKVExternalSorter`: reuse the point array between the sorter and the map, to avoid an extra memory allocation. This sounds like a reasonable optimization, the length of the `BytesToBytesMap` point array is at least 4 times larger than the number of keys(to avoid hash collision, the hash table size should be at least 2 times larger than the number of keys, and each key occupies 2 slots). `UnsafeInMemorySorter` needs the pointer array size to be 4 times of the number of entries, so we are safe to reuse the point array. However, the number of keys of the map doesn't equal to the number of entries in the map, because `BytesToBytesMap` supports duplicated keys. This breaks the assumption of the above optimization and we may run out of space when inserting data into the sorter, and hit error ``` java.lang.IllegalStateException: There is no space for new record at org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.insertRecord(UnsafeInMemorySorter.java:239) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:149) ... ``` This PR fixes this bug by creating a new point array if the existing one is not big enough. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20561 from cloud-fan/bug. --- .../sql/execution/UnsafeKVExternalSorter.java | 31 +++++++++++---- .../UnsafeKVExternalSorterSuite.scala | 39 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index b0b5383a081a0..9eb03430a7db2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.unsafe.sort.*; @@ -98,19 +99,33 @@ public UnsafeKVExternalSorter( numElementsForSpillThreshold, canUseRadixSort); } else { - // The array will be used to do in-place sort, which require half of the space to be empty. - // Note: each record in the map takes two entries in the array, one is record pointer, - // another is the key prefix. - assert(map.numKeys() * 2 <= map.getArray().size() / 2); - // During spilling, the array in map will not be used, so we can borrow that and use it - // as the underlying array for in-memory sorter (it's always large enough). - // Since we will not grow the array, it's fine to pass `null` as consumer. + // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow + // that and use it as the pointer array for `UnsafeInMemorySorter`. + LongArray pointerArray = map.getArray(); + // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but + // `UnsafeInMemorySorter`'s pointer array need to hold all the entries. Since + // `BytesToBytesMap` can have duplicated keys, here we need a check to make sure the pointer + // array can hold all the entries in `BytesToBytesMap`. + // The pointer array will be used to do in-place sort, which requires half of the space to be + // empty. Note: each record in the map takes two entries in the pointer array, one is record + // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`. + // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key, + // so that we can always reuse the pointer array. + if (map.numValues() > pointerArray.size() / 4) { + // Here we ask the map to allocate memory, so that the memory manager won't ask the map + // to spill, if the memory is not enough. + pointerArray = map.allocateArray(map.numValues() * 4L); + } + + // Since the pointer array(either reuse the one in the map, or create a new one) is guaranteed + // to be large enough, it's fine to pass `null` as consumer because we won't allocate more + // memory. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( null, taskMemoryManager, comparatorSupplier.get(), prefixComparator, - map.getArray(), + pointerArray, canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 6af9f8b77f8d3..bf588d3bb7841 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.map.BytesToBytesMap /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -205,4 +206,42 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { spill = true ) } + + test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") { + val memoryManager = new TestMemoryManager(new SparkConf()) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes()) + + // Key/value are a unsafe rows with a single int column + val schema = new StructType().add("i", IntegerType) + val key = new UnsafeRow(1) + key.pointTo(new Array[Byte](32), 32) + key.setInt(0, 1) + val value = new UnsafeRow(1) + value.pointTo(new Array[Byte](32), 32) + value.setInt(0, 2) + + for (_ <- 1 to 65) { + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + + // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` + // which has duplicated keys and the number of entries exceeds its capacity. + try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + new UnsafeKVExternalSorter( + schema, + schema, + sparkContext.env.blockManager, + sparkContext.env.serializerManager, + taskMemoryManager.pageSizeBytes(), + Int.MaxValue, + map) + } finally { + TaskContext.unset() + } + } } From c0c902aedcf9ed24e482d873d766a7df63b964cb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 11 Feb 2018 20:15:30 -0600 Subject: [PATCH 014/593] [SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance ## What changes were proposed in this pull request? In #19340 some comments considered needed to use spherical KMeans when cosine distance measure is specified, as Matlab does; instead of the implementation based on the behavior of other tools/libraries like Rapidminer, nltk and ELKI, ie. the centroids are computed as the mean of all the points in the clusters. The PR introduce the approach used in spherical KMeans. This behavior has the nice feature to minimize the within-cluster cosine distance. ## How was this patch tested? existing/improved UTs Author: Marco Gaido Closes #20518 from mgaido91/SPARK-22119_followup. --- .../spark/mllib/clustering/KMeans.scala | 54 ++++++++++++++++--- .../spark/ml/clustering/KMeansSuite.scala | 15 +++++- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 607145cb59fba..3c4ba0bc60c7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -310,8 +310,7 @@ class KMeans private ( points.foreach { point => val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) - val sum = sums(bestCenter) - axpy(1.0, point.vector, sum) + distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) counts(bestCenter) += 1 } @@ -319,10 +318,9 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.mapValues { case (sum, count) => - scal(1.0 / count, sum) - new VectorWithNorm(sum) - }.collectAsMap() + }.collectAsMap().mapValues { case (sum, count) => + distanceMeasureInstance.centroid(sum, count) + } bcCenters.destroy(blocking = false) @@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable { v1: VectorWithNorm, v2: VectorWithNorm): Double + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } } @Since("2.4.0") @@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure { * @return the cosine distance between the two input vectors */ override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e4506f23feb31..32830b39407ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == predictionsMap(Vectors.dense(-100.0, 90.0))) + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } + + test("KMeans with cosine distance is not supported for 0-length vectors") { + val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } test("read/write") { From 6efd5d117e98074d1b16a5c991fbd38df9aa196e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 11 Feb 2018 23:46:23 -0800 Subject: [PATCH 015/593] [SPARK-23390][SQL] Flaky Test Suite: FileBasedDataSourceSuite in Spark 2.3/hadoop 2.7 ## What changes were proposed in this pull request? This test only fails with sbt on Hadoop 2.7, I can't reproduce it locally, but here is my speculation by looking at the code: 1. FileSystem.delete doesn't delete the directory entirely, somehow we can still open the file as a 0-length empty file.(just speculation) 2. ORC intentionally allow empty files, and the reader fails during reading without closing the file stream. This PR improves the test to make sure all files are deleted and can't be opened. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20584 from cloud-fan/flaky-test. --- .../spark/sql/FileBasedDataSourceSuite.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 640d6b1583663..2e332362ea644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.FileNotFoundException + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -102,17 +104,27 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { def testIgnoreMissingFiles(): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) + val df = spark.read.format(format).load( new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + // Make sure all data files are deleted and can't be opened. + files.foreach(f => fs.delete(f, false)) assert(fs.delete(thirdPath, true)) + for (f <- files) { + intercept[FileNotFoundException](fs.open(f)) + } + checkAnswer(df, Seq(Row("0"), Row("1"))) } } From c338c8cf8253c037ecd4f39bbd58ed5a86581b37 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 12 Feb 2018 20:49:36 +0900 Subject: [PATCH 016/593] [SPARK-23352][PYTHON] Explicitly specify supported types in Pandas UDFs ## What changes were proposed in this pull request? This PR targets to explicitly specify supported types in Pandas UDFs. The main change here is to add a deduplicated and explicit type checking in `returnType` ahead with documenting this; however, it happened to fix multiple things. 1. Currently, we don't support `BinaryType` in Pandas UDFs, for example, see: ```python from pyspark.sql.functions import pandas_udf pudf = pandas_udf(lambda x: x, "binary") df = spark.createDataFrame([[bytearray(1)]]) df.select(pudf("_1")).show() ``` ``` ... TypeError: Unsupported type in conversion to Arrow: BinaryType ``` We can document this behaviour for its guide. 2. Also, the grouped aggregate Pandas UDF fails fast on `ArrayType` but seems we can support this case. ```python from pyspark.sql.functions import pandas_udf, PandasUDFType foo = pandas_udf(lambda v: v.mean(), 'array', PandasUDFType.GROUPED_AGG) df = spark.range(100).selectExpr("id", "array(id) as value") df.groupBy("id").agg(foo("value")).show() ``` ``` ... NotImplementedError: ArrayType, StructType and MapType are not supported with PandasUDFType.GROUPED_AGG ``` 3. Since we can check the return type ahead, we can fail fast before actual execution. ```python # we can fail fast at this stage because we know the schema ahead pandas_udf(lambda x: x, BinaryType()) ``` ## How was this patch tested? Manually tested and unit tests for `BinaryType` and `ArrayType(...)` were added. Author: hyukjinkwon Closes #20531 from HyukjinKwon/pudf-cleanup. --- docs/sql-programming-guide.md | 4 +- python/pyspark/sql/tests.py | 130 +++++++++++------- python/pyspark/sql/types.py | 4 + python/pyspark/sql/udf.py | 36 +++-- python/pyspark/worker.py | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- 6 files changed, 111 insertions(+), 67 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index eab4030ee25d2..6174a93b68492 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) @@ -1734,7 +1734,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p ### Supported SQL Types -Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`, `ArrayType` of `TimestampType`, and nested `StructType`. ### Setting Arrow Batch Size diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe89bd0685027..2af218a691026 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3790,10 +3790,10 @@ def foo(x): self.assertEqual(foo.returnType, schema) self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) def foo(x): return x - self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.returnType, DoubleType()) self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) @@ -3830,7 +3830,7 @@ def zero_with_type(): @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df - with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df @@ -3879,7 +3879,7 @@ def random_udf(v): return random_udf def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -3887,7 +3887,8 @@ def test_vectorized_udf_basic(self): col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) + col('id').cast('boolean').alias('bool'), + array(col('id')).alias('array_long')) f = lambda x: x str_f = pandas_udf(f, StringType()) int_f = pandas_udf(f, IntegerType()) @@ -3896,10 +3897,11 @@ def test_vectorized_udf_basic(self): double_f = pandas_udf(f, DoubleType()) decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) + array_long_f = pandas_udf(f, ArrayType(LongType())) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) + bool_f(col('bool')), array_long_f('array_long')) self.assertEquals(df.collect(), res.collect()) def test_register_nondeterministic_vectorized_udf_basic(self): @@ -4104,10 +4106,11 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.select(f(col('id'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): from pyspark.sql.functions import pandas_udf, col @@ -4142,13 +4145,18 @@ def test_vectorized_udf_varargs(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('map'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col @@ -4379,15 +4387,16 @@ def data(self): .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data + def test_supported_types(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + df = self.data.withColumn("arr", array(col("id"))) foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), StructType( [StructField('id', LongType()), StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), StructField('v1', DoubleType()), StructField('v2', LongType())]), PandasUDFType.GROUPED_MAP @@ -4490,17 +4499,15 @@ def test_datatype_string(self): def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v map', - PandasUDFType.GROUPED_MAP - ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.groupby('id').apply(foo).sort('id').toPandas() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf( + lambda pdf: pdf, + 'id long, v map', + PandasUDFType.GROUPED_MAP) def test_wrong_args(self): from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType @@ -4519,23 +4526,30 @@ def test_wrong_args(self): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) + df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), - PandasUDFType.SCALAR)) + pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType + from pyspark.sql.functions import pandas_udf, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.groupby('id').apply(f).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + schema = StructType( + [StructField("id", LongType(), True), + StructField("arr_ts", ArrayType(TimestampType()), True)]) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) # Regression test for SPARK-23314 def test_timestamp_dst(self): @@ -4614,23 +4628,32 @@ def weighted_mean(v, w): return weighted_mean def test_manual(self): + from pyspark.sql.functions import pandas_udf, array + df = self.data sum_udf = self.pandas_agg_sum_udf mean_udf = self.pandas_agg_mean_udf - - result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + mean_arr_udf = pandas_udf( + self.pandas_agg_mean_udf.func, + ArrayType(self.pandas_agg_mean_udf.returnType), + self.pandas_agg_mean_udf.evalType) + + result1 = df.groupby('id').agg( + sum_udf(df.v), + mean_udf(df.v), + mean_arr_udf(array(df.v))).sort('id') expected1 = self.spark.createDataFrame( - [[0, 245.0, 24.5], - [1, 255.0, 25.5], - [2, 265.0, 26.5], - [3, 275.0, 27.5], - [4, 285.0, 28.5], - [5, 295.0, 29.5], - [6, 305.0, 30.5], - [7, 315.0, 31.5], - [8, 325.0, 32.5], - [9, 335.0, 33.5]], - ['id', 'sum(v)', 'avg(v)']) + [[0, 245.0, 24.5, [24.5]], + [1, 255.0, 25.5, [25.5]], + [2, 265.0, 26.5, [26.5]], + [3, 275.0, 27.5, [27.5]], + [4, 285.0, 28.5, [28.5]], + [5, 295.0, 29.5, [29.5]], + [6, 305.0, 30.5, [30.5]], + [7, 315.0, 31.5, [31.5]], + [8, 325.0, 32.5, [32.5]], + [9, 335.0, 33.5, [33.5]]], + ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) @@ -4667,14 +4690,15 @@ def test_basic(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_unsupported_types(self): - from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.types import DoubleType, MapType from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return [v.mean(), v.std()] + pandas_udf( + lambda x: x, + ArrayType(ArrayType(TimestampType())), + PandasUDFType.GROUPED_AGG) with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f7141b4549e4e..e25941cd37595 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1638,6 +1638,8 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: + if type(dt.elementType) == TimestampType: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -1680,6 +1682,8 @@ def from_arrow_type(at): elif types.is_timestamp(at): spark_type = TimestampType() elif types.is_list(at): + if types.is_timestamp(at.value_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 08c6b9e521e82..e5b35fc60e167 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -23,7 +23,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string + _parse_datatype_string, to_arrow_type, to_arrow_schema __all__ = ["UDFRegistration"] @@ -112,15 +112,31 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ - and not isinstance(self._returnType_placeholder, StructType): - raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUPED_MAP") - elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ - and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): - raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with " - "PandasUDFType.GROUPED_AGG") + if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with scalar Pandas UDFs: %s is " + "not supported" % str(self._returnType_placeholder)) + elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_schema(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped map Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) + else: + raise TypeError("Invalid returnType for grouped map Pandas " + "UDFs: returnType must be a StructType.") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped aggregate Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 121b3dd1aeec9..89a3a92bc66d6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -116,7 +116,7 @@ def wrap_grouped_agg_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd result = f(*series) - return pd.Series(result) + return pd.Series([result]) return lambda *a: (wrapped(*a), arrow_return_type) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1e2501ee7757d..7835dbaa58439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,7 +1064,7 @@ object SQLConf { "for use with pyspark.sql.DataFrame.toPandas, and " + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + "The following data types are unsupported: " + - "MapType, ArrayType of TimestampType, and nested StructType.") + "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From caeb108e25e5bfb7cffcf09ef9abbb1abcfa355d Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 12 Feb 2018 22:05:27 +0800 Subject: [PATCH 017/593] [MINOR][TEST] spark.testing` No effect on the SparkFunSuite unit test ## What changes were proposed in this pull request? Currently, we use SBT and MAVN to spark unit test, are affected by the parameters of `spark.testing`. However, when using the IDE test tool, `spark.testing` support is not very good, sometimes need to be manually added to the beforeEach. example: HiveSparkSubmitSuite RPackageUtilsSuite SparkSubmitSuite. The PR unified `spark.testing` parameter extraction to SparkFunSuite, support IDE test tool, and the test code is more compact. ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20582 from heary-cao/sparktesting. --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 1 + .../test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala | 1 - .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 1 - .../spark/network/netty/NettyBlockTransferServiceSuite.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 1 - 5 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 3af9d82393bc4..31289026b0027 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -59,6 +59,7 @@ abstract class SparkFunSuite protected val enableAutoThreadAudit = true protected override def beforeAll(): Unit = { + System.setProperty("spark.testing", "true") if (enableAutoThreadAudit) { doThreadPreAudit() } diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 32dd3ecc2f027..ef947eb074647 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -66,7 +66,6 @@ class RPackageUtilsSuite override def beforeEach(): Unit = { super.beforeEach() - System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 27dd435332348..803a38d77fb82 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -107,7 +107,6 @@ class SparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } // scalastyle:off println diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index f7bc3725d7278..78423ee68a0ec 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,6 +80,7 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests + // if `spark.testing` is true, // the default value for `spark.port.maxRetries` is 100 under test actualPort should be <= (expectedPort + 100) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10204f4694663..2d31781132edc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -50,7 +50,6 @@ class HiveSparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } test("temporary Hive UDF: define a UDF and use it") { From 0e2c266de7189473177f45aa68ea6a45c7e47ec3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 22:07:59 +0800 Subject: [PATCH 018/593] [SPARK-22977][SQL] fix web UI SQL tab for CTAS ## What changes were proposed in this pull request? This is a regression in Spark 2.3. In Spark 2.2, we have a fragile UI support for SQL data writing commands. We only track the input query plan of `FileFormatWriter` and display its metrics. This is not ideal because we don't know who triggered the writing(can be table insertion, CTAS, etc.), but it's still useful to see the metrics of the input query. In Spark 2.3, we introduced a new mechanism: `DataWritigCommand`, to fix the UI issue entirely. Now these writing commands have real children, and we don't need to hack into the `FileFormatWriter` for the UI. This also helps with `explain`, now `explain` can show the physical plan of the input query, while in 2.2 the physical writing plan is simply `ExecutedCommandExec` and it has no child. However there is a regression in CTAS. CTAS commands don't extend `DataWritigCommand`, and we don't have the UI hack in `FileFormatWriter` anymore, so the UI for CTAS is just an empty node. See https://issues.apache.org/jira/browse/SPARK-22977 for more information about this UI issue. To fix it, we should apply the `DataWritigCommand` mechanism to CTAS commands. TODO: In the future, we should refactor this part and create some physical layer code pieces for data writing, and reuse them in different writing commands. We should have different logical nodes for different operators, even some of them share some same logic, e.g. CTAS, CREATE TABLE, INSERT TABLE. Internally we can share the same physical logic. ## How was this patch tested? manually tested. For data source table 1 For hive table 2 Author: Wenchen Fan Closes #20521 from cloud-fan/UI. --- .../command/createDataSourceTables.scala | 21 +++---- .../execution/datasources/DataSource.scala | 44 +++++++++++++-- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../CreateHiveTableAsSelectCommand.scala | 55 ++++++++++--------- .../sql/hive/execution/HiveExplainSuite.scala | 26 --------- 6 files changed, 80 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 306f43dc4214a..e9747769dfcfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,7 +21,9 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -136,12 +138,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, - query: LogicalPlan) - extends RunnableCommand { - - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + query: LogicalPlan, + outputColumns: Seq[Attribute]) + extends DataWritingCommand { - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) assert(table.provider.isDefined) @@ -163,7 +164,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) + sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -173,7 +174,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) + sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of @@ -198,10 +199,10 @@ case class CreateDataSourceTableAsSelectCommand( session: SparkSession, table: CatalogTable, tableLocation: Option[URI], - data: LogicalPlan, + physicalPlan: SparkPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { - // Create the relation based on the input logical plan: `data`. + // Create the relation based on the input logical plan: `query`. val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, @@ -212,7 +213,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query) + dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 25e1210504273..6e1b5727e3fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -31,8 +31,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -435,10 +437,11 @@ case class DataSource( } /** - * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. + * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]]. + * The returned command is unresolved and need to be analyzed. */ private def planForWritingFileFormat( - format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { + format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -482,9 +485,24 @@ case class DataSource( /** * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. + * + * @param mode The save mode for this writing. + * @param data The input query plan that produces the data to be written. Note that this plan + * is analyzed and optimized. + * @param outputColumns The original output columns of the input query plan. The optimizer may not + * preserve the output column's names' case, so we need this parameter + * instead of `data.output`. + * @param physicalPlan The physical plan of the input query plan. We should run the writing + * command with this physical plan instead of creating a new physical plan, + * so that the metrics can be correctly linked to the given physical plan and + * shown in the web UI. */ - def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + def writeAndRead( + mode: SaveMode, + data: LogicalPlan, + outputColumns: Seq[Attribute], + physicalPlan: SparkPlan): BaseRelation = { + if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -493,9 +511,23 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd + val cmd = planForWritingFileFormat(format, mode, data) + val resolvedPartCols = cmd.partitionColumns.map { col => + // The partition columns created in `planForWritingFileFormat` should always be + // `UnresolvedAttribute` with a single name part. + assert(col.isInstanceOf[UnresolvedAttribute]) + val unresolved = col.asInstanceOf[UnresolvedAttribute] + assert(unresolved.nameParts.length == 1) + val name = unresolved.nameParts.head + outputColumns.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d94c5bbccdd84..3f41612c08065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ab857b9055720..8df05cbb20361 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -157,7 +157,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 65e8b4e3c725c..1e801fe1845c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand /** @@ -36,15 +37,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, + outputColumns: Seq[Attribute], mode: SaveMode) - extends RunnableCommand { + extends DataWritingCommand { private val tableIdentifier = tableDesc.identifier - override def innerChildren: Seq[LogicalPlan] = Seq(query) - - override def run(sparkSession: SparkSession): Seq[Row] = { - if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + if (catalog.tableExists(tableIdentifier)) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -56,34 +57,36 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = false, - ifPartitionNotExists = false)).toRdd + InsertIntoHiveTable( + tableDesc, + Map.empty, + query, + overwrite = false, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - sparkSession.sessionState.catalog.createTable( - tableDesc.copy(schema = query.schema), ignoreIfExists = false) + catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) try { - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = true, - ifPartitionNotExists = false)).toRdd + // Read back the metadata of the table which was created just now. + val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) + // For CTAS, there is no static partition values to insert. + val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + createdTableMeta, + partition, + query, + overwrite = true, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. - sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, - purge = false) + catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false) throw e } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index f84d188075b72..5d56f89c2271c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -128,32 +128,6 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { - withTempView("jt") { - val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() - spark.read.json(ds).createOrReplaceTempView("jt") - val outputs = sql( - s""" - |EXPLAIN EXTENDED - |CREATE TABLE t1 - |AS - |SELECT * FROM jt - """.stripMargin).collect().map(_.mkString).mkString - - val shouldContain = - "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: - "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: - "CreateHiveTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil - for (key <- shouldContain) { - assert(outputs.contains(key), s"$key doesn't exist in result") - } - - val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should contain SubqueryAlias since the query should not be optimized") - } - } - test("explain output of physical plan should contain proper codegen stage ID") { checkKeywordsExist(sql( """ From 4a4dd4f36f65410ef5c87f7b61a960373f044e61 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 12 Feb 2018 08:49:45 -0600 Subject: [PATCH 019/593] [SPARK-23391][CORE] It may lead to overflow for some integer multiplication ## What changes were proposed in this pull request? In the `getBlockData`,`blockId.reduceId` is the `Int` type, when it is greater than 2^28, `blockId.reduceId*8` will overflow In the `decompress0`, `len` and `unitSize` are Int type, so `len * unitSize` may lead to overflow ## How was this patch tested? N/A Author: liuxian Closes #20581 from 10110346/overflow2. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../execution/columnar/compression/compressionSchemes.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d88b25cc7e258..d3f1c7ec1bbee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -202,13 +202,13 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - channel.position(blockId.reduceId * 8) + channel.position(blockId.reduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { val offset = in.readLong() val nextOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8 + 16 + val expectedPosition = blockId.reduceId * 8L + 16 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 79dcf3a6105ce..00a1d54b41709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -116,7 +116,7 @@ private[columnar] case object PassThrough extends CompressionScheme { while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos - assert(len * unitSize < Int.MaxValue) + assert(len * unitSize.toLong < Int.MaxValue) putFunction(columnVector, pos, bufferPos, len) bufferPos += len * unitSize pos += len From 5bb11411aec18b8d623e54caba5397d7cb8e89f0 Mon Sep 17 00:00:00 2001 From: James Thompson Date: Mon, 12 Feb 2018 11:34:56 -0800 Subject: [PATCH 020/593] [SPARK-23388][SQL] Support for Parquet Binary DecimalType in VectorizedColumnReader ## What changes were proposed in this pull request? Re-add support for parquet binary DecimalType in VectorizedColumnReader ## How was this patch tested? Existing test suite Author: James Thompson Closes #20580 from jamesthomp/jt/add-back-binary-decimal. --- .../execution/datasources/parquet/VectorizedColumnReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c120863152a96..47dd625f4b154 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -444,7 +444,8 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType + || DecimalType.isByteArrayDecimalType(column.dataType())) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { if (!shouldConvertTimestamps()) { From 0c66fe4f22f8af4932893134bb0fd56f00fabeae Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 12 Feb 2018 12:20:29 -0800 Subject: [PATCH 021/593] [SPARK-22002][SQL][FOLLOWUP][TEST] Add a test to check if the original schema doesn't have metadata. ## What changes were proposed in this pull request? This is a follow-up pr of #19231 which modified the behavior to remove metadata from JDBC table schema. This pr adds a test to check if the schema doesn't have metadata. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20585 from ueshin/issues/SPARK-22002/fup1. --- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cb2df0ac54f4c..5238adce4a699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1168,4 +1168,26 @@ class JDBCSuite extends SparkFunSuite val df3 = sql("SELECT * FROM test_sessionInitStatement") assert(df3.collect() === Array(Row(21519, 1234))) } + + test("jdbc data source shouldn't have unnecessary metadata in its schema") { + val schema = StructType(Seq( + StructField("NAME", StringType, true), StructField("THEID", IntegerType, true))) + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("DbTaBle", "TEST.PEOPLE") + .load() + assert(df.schema === schema) + + withTempView("people_view") { + sql( + s""" + |CREATE TEMPORARY VIEW people_view + |USING org.apache.spark.sql.jdbc + |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + assert(sql("select * from people_view").schema === schema) + } + } } From fba01b9a65e5d9438d35da0bd807c179ba741911 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 14:58:31 -0800 Subject: [PATCH 022/593] [SPARK-23378][SQL] move setCurrentDatabase from HiveExternalCatalog to HiveClientImpl ## What changes were proposed in this pull request? This removes the special case that `alterPartitions` call from `HiveExternalCatalog` can reset the current database in the hive client as a side effect. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20564 from liufengdb/move. --- .../spark/sql/hive/HiveExternalCatalog.scala | 5 ---- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 3b8a8ca301c27..1ee1d57b8ebe1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1107,11 +1107,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - client.setCurrentDatabase(db) client.alterPartitions(db, table, withStatsProps) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6c0f4144992ae..c223f51b1be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -291,14 +291,18 @@ private[hive] class HiveClientImpl( state.err = stream } - override def setCurrentDatabase(databaseName: String): Unit = withHiveState { - if (databaseExists(databaseName)) { - state.setCurrentDatabase(databaseName) + private def setCurrentDatabaseRaw(db: String): Unit = { + if (databaseExists(db)) { + state.setCurrentDatabase(db) } else { - throw new NoSuchDatabaseException(databaseName) + throw new NoSuchDatabaseException(db) } } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + setCurrentDatabaseRaw(databaseName) + } + override def createDatabase( database: CatalogDatabase, ignoreIfExists: Boolean): Unit = withHiveState { @@ -598,8 +602,18 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(userName)) - shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + val original = state.getCurrentDatabase + try { + setCurrentDatabaseRaw(db) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) + shim.alterPartitions(client, table, newParts.map { toHivePartition(_, hiveTable) }.asJava) + } finally { + state.setCurrentDatabase(original) + } } /** From 6cb59708c70c03696c772fbb5d158eed57fe67d4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Feb 2018 15:26:37 -0800 Subject: [PATCH 023/593] [SPARK-23313][DOC] Add a migration guide for ORC ## What changes were proposed in this pull request? This PR adds a migration guide documentation for ORC. ![orc-guide](https://user-images.githubusercontent.com/9700541/36123859-ec165cae-1002-11e8-90b7-7313be7a81a5.png) ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20484 from dongjoon-hyun/SPARK-23313. --- docs/sql-programming-guide.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6174a93b68492..0f9f01e18682f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1776,6 +1776,35 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 + - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. + + - New configurations + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
+ + - Changed configurations + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
+ - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. From 4104b68e958cd13975567a96541dac7cccd8195c Mon Sep 17 00:00:00 2001 From: sychen Date: Mon, 12 Feb 2018 16:00:47 -0800 Subject: [PATCH 024/593] [SPARK-23230][SQL] When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error. We should take the default type of textfile and sequencefile both as org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe. ``` set hive.default.fileformat=orc; create table tbl( i string ) stored as textfile; desc formatted tbl; Serde Library org.apache.hadoop.hive.ql.io.orc.OrcSerde InputFormat org.apache.hadoop.mapred.TextInputFormat OutputFormat org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat ``` Author: sychen Closes #20406 from cxzl25/default_serde. --- .../apache/spark/sql/internal/HiveSerDe.scala | 6 ++++-- .../sql/hive/execution/HiveSerDeSuite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index dac463641cfab..eca612f06f9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -31,7 +31,8 @@ object HiveSerDe { "sequencefile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "rcfile" -> HiveSerDe( @@ -54,7 +55,8 @@ object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "avro" -> HiveSerDe( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 1c9f00141ae1d..d7752e987cb4b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -100,6 +100,25 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS textfile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS sequencefile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.SequenceFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } } test("create hive serde table with new syntax - basic") { From c1bcef876c1415e39e624cfbca9c9bdeae24cbb9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 13 Feb 2018 11:40:34 +0800 Subject: [PATCH 025/593] [SPARK-23323][SQL] Support commit coordinator for DataSourceV2 writes ## What changes were proposed in this pull request? DataSourceV2 batch writes should use the output commit coordinator if it is required by the data source. This adds a new method, `DataWriterFactory#useCommitCoordinator`, that determines whether the coordinator will be used. If the write factory returns true, `WriteToDataSourceV2` will use the coordinator for batch writes. ## How was this patch tested? This relies on existing write tests, which now use the commit coordinator. Author: Ryan Blue Closes #20490 from rdblue/SPARK-23323-add-commit-coordinator. --- .../sources/v2/writer/DataSourceWriter.java | 19 +++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 41 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index e3f682bf96a66..0a0fd8db58035 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -63,6 +63,16 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for + * each task commits. + * + * @return true if commit coordinator should be used, false otherwise. + */ + default boolean useCommitCoordinator() { + return true; + } + /** * Handles a commit message on receiving from a successful data writer. * @@ -79,10 +89,11 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * - * Note that, one partition may have multiple committed data writers because of speculative tasks. - * Spark will pick the first successful one and get its commit message. Implementations should be - * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data - * writer can commit, or have a way to clean up the data of already-committed writers. + * Note that speculative execution may cause multiple tasks to run for a partition. By default, + * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple + * attempts may have committed successfully and one successful commit message per task will be + * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index eefbcf4c0e087..535e7962d7439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -53,6 +54,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } + val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) @@ -73,7 +75,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e DataWritingSparkTask.runContinuous(writeTask, context, iter) case _ => (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter) + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) } sparkContext.runJob( @@ -116,21 +118,44 @@ object DataWritingSparkTask extends Logging { def run( writeTask: DataWriterFactory[InternalRow], context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean): WriterCommitMessage = { + val stageId = context.stageId() + val partId = context.partitionId() + val attemptId = context.attemptNumber() + val dataWriter = writeTask.createDataWriter(partId, attemptId) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { iter.foreach(dataWriter.write) - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + if (commitAuthorized) { + logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + dataWriter.commit() + } else { + val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + logInfo(message) + // throwing CommitDeniedException will trigger the catch block for abort + throw new CommitDeniedException(message, stageId, partId, attemptId) + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + msg + })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for partition ${context.partitionId()} is aborting.") + logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") + logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } From ed4e78bd606e7defc2cd01a5c2e9b47954baa424 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 20:57:26 -0800 Subject: [PATCH 026/593] [SPARK-23379][SQL] skip when setting the same current database in HiveClientImpl ## What changes were proposed in this pull request? If the target database name is as same as the current database, we should be able to skip one metastore access. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20565 from liufengdb/remove-redundant. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c223f51b1be75..146fa54a1bce4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -292,10 +292,12 @@ private[hive] class HiveClientImpl( } private def setCurrentDatabaseRaw(db: String): Unit = { - if (databaseExists(db)) { - state.setCurrentDatabase(db) - } else { - throw new NoSuchDatabaseException(db) + if (state.getCurrentDatabase != db) { + if (databaseExists(db)) { + state.setCurrentDatabase(db) + } else { + throw new NoSuchDatabaseException(db) + } } } From f17b936f0ddb7d46d1349bd42f9a64c84c06e48d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 21:12:22 -0800 Subject: [PATCH 027/593] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- Relation AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- Relation AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == Relation AdvancedDataSourceV2[j#1] == Physical Plan == *(1) Scan AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) Scan JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) Scan FakeDataSourceV2$[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20477 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +--- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../v2/DataSourceV2QueryPlan.scala | 96 +++++++++++++++++++ .../datasources/v2/DataSourceV2Relation.scala | 26 +++-- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 +++-- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 +-- 15 files changed, 157 insertions(+), 127 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..72ee0c551ec3d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,8 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..d34458ac81014 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 02c87643568bd..cb09cce75ff6f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,7 +117,8 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..984b6510f2dbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,11 +189,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() + val ds = cls.newInstance().asInstanceOf[DataSourceV2] val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -221,7 +219,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala new file mode 100644 index 0000000000000..1e0d088f3a57c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A base class for data source v2 related query plan(both logical and physical). It defines the + * equals/hashCode methods, and provides a string representation of the query plan, according to + * some common information. + */ +trait DataSourceV2QueryPlan { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + /** + * The metadata of this data source query plan that can be used for equality check. + */ + private def metadata: Seq[Any] = Seq(output, source.getClass, filters) + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) + }, " (", ", ", ")") + } else { + "" + } + + s"${source.getClass.getSimpleName}$outputStr$entriesStr" + } + + private def redact(text: String): String = { + Utils.redact(SQLConf.get.stringRedationPattern, text) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..cd97e0cab6b5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,15 +20,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + reader: DataSourceReader, + override val isStreaming: Boolean) + extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + override def simpleString: String = { + val streamingHeader = if (isStreaming) "Streaming " else "" + s"${streamingHeader}Relation $metadataString" + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -41,18 +49,8 @@ case class DataSourceV2Relation( } } -/** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. - */ -class StreamingDataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { - override def isStreaming: Boolean = true -} - object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..c99d535efcf81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,11 +37,14 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = s"Scan $metadataString" + override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..fb61e6f32b1f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..4cfdd50e8f46b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + case FilterAndProject(fields, condition, r: DataSourceV2Relation) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = reader match { + val stayUpFilters: Seq[Expression] = r.reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..84564b6639ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -90,6 +92,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -405,12 +408,15 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) + reader.setOffsetRange(toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + Some(reader -> new DataSourceV2Relation( + reader.readSchema().toAttributes, + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), + reader, + isStreaming = true)) case _ => None } } @@ -500,3 +506,5 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..f87d57d0b3209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(ds, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,8 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => + r.reader.asInstanceOf[ContinuousReader] }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..70eb9f0ac66d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..254394685857b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case d: DataSourceV2Relation => d.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..9ee9aaf87f87c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 407f67249639709c40c46917700ed6dd736daa7d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 15:05:13 +0900 Subject: [PATCH 028/593] [SPARK-20090][FOLLOW-UP] Revert the deprecation of `names` in PySpark ## What changes were proposed in this pull request? Deprecating the field `name` in PySpark is not expected. This PR is to revert the change. ## How was this patch tested? N/A Author: gatorsmile Closes #20595 from gatorsmile/removeDeprecate. --- python/pyspark/sql/types.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e25941cd37595..cd857402db8f7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -455,9 +455,6 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. - .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead - to get a list of field names. - >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) From 9dae715168a8e72e318ab231c34a1069bfa342a6 Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Tue, 13 Feb 2018 06:20:34 -0600 Subject: [PATCH 029/593] [SPARK-23318][ML] FP-growth: WARN FPGrowth: Input data is not cached ## What changes were proposed in this pull request? Cache the RDD of items in ml.FPGrowth before passing it to mllib.FPGrowth. Cache only when the user did not cache the input dataset of transactions. This fixes the warning about uncached data emerging from mllib.FPGrowth. ## How was this patch tested? Manually: 1. Run ml.FPGrowthExample - warning is there 2. Apply the fix 3. Run ml.FPGrowthExample again - no warning anymore Author: Arseniy Tashoyan Closes #20578 from tashoyan/SPARK-23318. --- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index aa7871d6ff29d..3d041fc80eb7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel /** * Common params for FPGrowth and FPGrowthModel @@ -158,18 +159,30 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val data = dataset.select($(itemsCol)) - val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) } + + if (handlePersistence) { + items.persist(StorageLevel.MEMORY_AND_DISK) + } + val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + + if (handlePersistence) { + items.unpersist() + } + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) } From 300c40f50ab4258d697f06a814d1491dc875c847 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 06:23:10 -0600 Subject: [PATCH 030/593] [SPARK-23384][WEB-UI] When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. ## What changes were proposed in this pull request? When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. It is a bug. fix before: ![1](https://user-images.githubusercontent.com/26266482/36070635-264d7cf0-0f3a-11e8-8426-14135ffedb16.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/36070651-8ec3800e-0f3a-11e8-991c-6122cc9539fe.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20573 from guoxiaolongzte/SPARK-23384. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 5d62a7d8bebb4..6fc12d721e6f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - + ++ +
    @@ -65,7 +66,6 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++
    ++ - ++ ++ } else if (requestedIncomplete) { From 116c581d2658571d38f8b9b27a516ef517170589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 06:54:15 -0800 Subject: [PATCH 031/593] [SPARK-20659][CORE] Removing sc.getExecutorStorageStatus and making StorageStatus private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR StorageStatus is made to private and simplified a bit moreover SparkContext.getExecutorStorageStatus method is removed. The reason of keeping StorageStatus is that it is usage from SparkContext.getRDDStorageInfo. Instead of the method SparkContext.getExecutorStorageStatus executor infos are extended with additional memory metrics such as usedOnHeapStorageMemory, usedOffHeapStorageMemory, totalOnHeapStorageMemory, totalOffHeapStorageMemory. ## How was this patch tested? By running existing unit tests. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20546 from attilapiros/SPARK-20659. --- .../org/apache/spark/SparkExecutorInfo.java | 4 + .../scala/org/apache/spark/SparkContext.scala | 19 +- .../org/apache/spark/SparkStatusTracker.scala | 9 +- .../org/apache/spark/StatusAPIImpl.scala | 6 +- .../apache/spark/storage/StorageUtils.scala | 119 +--------- .../org/apache/spark/DistributedSuite.scala | 7 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../apache/spark/storage/StorageSuite.scala | 219 ------------------ project/MimaExcludes.scala | 14 ++ .../spark/repl/SingletonReplSuite.scala | 6 +- 10 files changed, 44 insertions(+), 361 deletions(-) diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java index dc3e826475987..2b93385adf103 100644 --- a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -30,4 +30,8 @@ public interface SparkExecutorInfo extends Serializable { int port(); long cacheSize(); int numRunningTasks(); + long usedOnHeapStorageMemory(); + long usedOffHeapStorageMemory(); + long totalOnHeapStorageMemory(); + long totalOffHeapStorageMemory(); } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3828d4f703247..c4f74c4f1f9c2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1715,7 +1715,13 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + val rddStorageInfo = statusStore.asOption(statusStore.rdd(rddId)) + rddInfo.numCachedPartitions = rddStorageInfo.map(_.numCachedPartitions).getOrElse(0) + rddInfo.memSize = rddStorageInfo.map(_.memoryUsed).getOrElse(0L) + rddInfo.diskSize = rddStorageInfo.map(_.diskUsed).getOrElse(0L) + } rddInfos.filter(_.isCached) } @@ -1726,17 +1732,6 @@ class SparkContext(config: SparkConf) extends Logging { */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - @deprecated("This method may change or be removed in a future release.", "2.2.0") - def getExecutorStorageStatus: Array[StorageStatus] = { - assertNotStopped() - env.blockManager.master.getStorageStatus - } - /** * :: DeveloperApi :: * Return pools for fair scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 70865cb58c571..815237eba0174 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -97,7 +97,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore } /** - * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. + * Returns information of all known executors, including host, port, cacheSize, numRunningTasks + * and memory metrics. */ def getExecutorInfos: Array[SparkExecutorInfo] = { store.executorList(true).map { exec => @@ -113,7 +114,11 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore host, port, cachedMem, - exec.activeTasks) + exec.activeTasks, + exec.memoryMetrics.map(_.usedOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.usedOnHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOnHeapStorageMemory).getOrElse(0L)) }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index c1f24a6377788..6a888c1e9e772 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -38,5 +38,9 @@ private class SparkExecutorInfoImpl( val host: String, val port: Int, val cacheSize: Long, - val numRunningTasks: Int) + val numRunningTasks: Int, + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) extends SparkExecutorInfo diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e9694fdbca2de..adc406bb1c441 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -24,19 +24,15 @@ import scala.collection.mutable import sun.nio.ch.DirectBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging /** - * :: DeveloperApi :: * Storage information for each BlockManager. * * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this * class cannot mutate the source of the information. Accesses are not thread-safe. */ -@DeveloperApi -@deprecated("This class may be removed or made private in a future release.", "2.2.0") -class StorageStatus( +private[spark] class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, val maxOnHeapMem: Option[Long], @@ -44,9 +40,6 @@ class StorageStatus( /** * Internal representation of the blocks stored in this block manager. - * - * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. - * These collections should only be mutated through the add/update/removeBlock methods. */ private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] @@ -87,9 +80,6 @@ class StorageStatus( */ def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - /** Return the blocks that belong to the given RDD stored in this block manager. */ - def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty) - /** Add the given block to this storage status. If it already exists, overwrite it. */ private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { updateStorageInfo(blockId, blockStatus) @@ -101,46 +91,6 @@ class StorageStatus( } } - /** Update the given block in this storage status. If it doesn't already exist, add it. */ - private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { - addBlock(blockId, blockStatus) - } - - /** Remove the given block from this storage status. */ - private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { - updateStorageInfo(blockId, BlockStatus.empty) - blockId match { - case RDDBlockId(rddId, _) => - // Actually remove the block, if it exists - if (_rddBlocks.contains(rddId)) { - val removed = _rddBlocks(rddId).remove(blockId) - // If the given RDD has no more blocks left, remove the RDD - if (_rddBlocks(rddId).isEmpty) { - _rddBlocks.remove(rddId) - } - removed - } else { - None - } - case _ => - _nonRddBlocks.remove(blockId) - } - } - - /** - * Return whether the given block is stored in this block manager in O(1) time. - * - * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. - */ - def containsBlock(blockId: BlockId): Boolean = { - blockId match { - case RDDBlockId(rddId, _) => - _rddBlocks.get(rddId).exists(_.contains(blockId)) - case _ => - _nonRddBlocks.contains(blockId) - } - } - /** * Return the given block stored in this block manager in O(1) time. * @@ -155,37 +105,12 @@ class StorageStatus( } } - /** - * Return the number of blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.blocks.size`, which is O(blocks) time. - */ - def numBlocks: Int = _nonRddBlocks.size + numRddBlocks - - /** - * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. - */ - def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - - /** - * Return the number of blocks that belong to the given RDD in O(1) time. - * - * @note This is much faster than `this.rddBlocksById(rddId).size`, which is - * O(blocks in this RDD) time. - */ - def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) - /** Return the max memory can be used by this block manager. */ def maxMem: Long = maxMemory /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed - /** Return the memory used by caching RDDs */ - def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) - /** Return the memory used by this block manager. */ def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) @@ -220,15 +145,9 @@ class StorageStatus( /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) - /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) - /** * Update the relevant storage info, taking into account any existing status for this block. */ @@ -295,40 +214,4 @@ private[spark] object StorageUtils extends Logging { cleaner.clean() } } - - /** - * Update the given list of RDDInfo with the given list of storage statuses. - * This method overwrites the old values stored in the RDDInfo's. - */ - def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { - rddInfos.foreach { rddInfo => - val rddId = rddInfo.id - // Assume all blocks belonging to the same RDD have the same storage level - val storageLevel = statuses - .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) - val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum - val memSize = statuses.map(_.memUsedByRdd(rddId)).sum - val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - - rddInfo.storageLevel = storageLevel - rddInfo.numCachedPartitions = numCachedPartitions - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - } - } - - /** - * Return a mapping from block ID to its locations for each block that belongs to the given RDD. - */ - def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] - statuses.foreach { status => - status.rddBlocksById(rddId).foreach { case (bid, _) => - val location = status.blockManagerId.hostPort - blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location - } - } - blockLocations - } - } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index e09d5f59817b9..28ea0c6f0bdba 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -160,11 +160,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) - assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === - storageLevel.replication * data.getNumPartitions) - assert(cachedData.count === 1000) - assert(cachedData.count === 1000) - + assert(sc.getRDDStorageInfo.filter(_.id == cachedData.id).map(_.numCachedPartitions).sum === + data.getNumPartitions) // Get all the locations of the first partition and try to fetch the partitions // from those locations. val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index bf7480d79f8a1..c21ee7d26f8ca 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -610,7 +610,7 @@ class StandaloneDynamicAllocationSuite * we submit a request to kill them. This must be called before each kill request. */ private def syncExecutors(sc: SparkContext): Unit = { - val driverExecutors = sc.getExecutorStorageStatus + val driverExecutors = sc.env.blockManager.master.getStorageStatus .map(_.blockManagerId.executorId) .filter { _ != SparkContext.DRIVER_IDENTIFIER} val masterExecutors = getExecutorIds(sc) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index da198f946fd64..ca352387055f4 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -51,27 +51,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsed === 60L) } - test("storage status update non-RDD blocks") { - val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) - assert(status.blocks.size === 3) - assert(status.memUsed === 160L) - assert(status.memRemaining === 840L) - assert(status.diskUsed === 140L) - } - - test("storage status remove non-RDD blocks") { - val status = storageStatus1 - status.removeBlock(TestBlockId("foo")) - status.removeBlock(TestBlockId("faa")) - assert(status.blocks.size === 1) - assert(status.blocks.contains(TestBlockId("fee"))) - assert(status.memUsed === 10L) - assert(status.memRemaining === 990L) - assert(status.diskUsed === 20L) - } - // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) @@ -95,85 +74,6 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocks.contains(RDDBlockId(2, 2))) assert(status.rddBlocks.contains(RDDBlockId(2, 3))) assert(status.rddBlocks.contains(RDDBlockId(2, 4))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1))) - assert(status.rddBlocksById(2).size === 3) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 30L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 80L) - assert(status.rddStorageLevel(0) === Some(memAndDisk)) - assert(status.rddStorageLevel(1) === Some(memAndDisk)) - assert(status.rddStorageLevel(2) === Some(memAndDisk)) - - // Verify default values for RDDs that don't exist - assert(status.rddBlocksById(10).isEmpty) - assert(status.memUsedByRdd(10) === 0L) - assert(status.diskUsedByRdd(10) === 0L) - assert(status.rddStorageLevel(10) === None) - } - - test("storage status update RDD blocks") { - val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) - assert(status.blocks.size === 7) - assert(status.rddBlocks.size === 5) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(2).size === 3) - assert(status.memUsedByRdd(0) === 0L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 20L) - assert(status.diskUsedByRdd(0) === 0L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 1060L) - } - - test("storage status remove RDD blocks") { - val status = storageStatus2 - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(1, 1)) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 4)) - assert(status.blocks.size === 3) - assert(status.rddBlocks.size === 2) - assert(status.rddBlocks.contains(RDDBlockId(0, 0))) - assert(status.rddBlocks.contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 0) - assert(status.rddBlocksById(2).size === 1) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 0L) - assert(status.memUsedByRdd(2) === 10L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 0L) - assert(status.diskUsedByRdd(2) === 20L) - } - - test("storage status containsBlock") { - val status = storageStatus2 - // blocks that actually exist - assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan"))) - assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man"))) - assert(status.blocks.contains(RDDBlockId(0, 0)) === status.containsBlock(RDDBlockId(0, 0))) - assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1))) - assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2))) - assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3))) - assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4))) - // blocks that don't exist - assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan"))) - assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0))) } test("storage status getBlock") { @@ -191,40 +91,6 @@ class StorageSuite extends SparkFunSuite { assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0))) } - test("storage status num[Rdd]Blocks") { - val status = storageStatus2 - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(100).size === status.numRddBlocksById(100)) - status.removeBlock(RDDBlockId(4, 0)) - status.removeBlock(RDDBlockId(10, 10)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - // remove a block that doesn't exist - status.removeBlock(RDDBlockId(1000, 999)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000)) - } test("storage status memUsed, diskUsed, externalBlockStoreUsed") { val status = storageStatus2 @@ -237,17 +103,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -273,65 +128,6 @@ class StorageSuite extends SparkFunSuite { Seq(info0, info1) } - test("StorageUtils.updateRddInfo") { - val storageStatuses = stockStorageStatuses - val rddInfos = stockRDDInfos - StorageUtils.updateRddInfo(rddInfos, storageStatuses) - assert(rddInfos(0).storageLevel === memAndDisk) - assert(rddInfos(0).numCachedPartitions === 5) - assert(rddInfos(0).memSize === 5L) - assert(rddInfos(0).diskSize === 10L) - assert(rddInfos(0).externalBlockStoreSize === 0L) - assert(rddInfos(1).storageLevel === memAndDisk) - assert(rddInfos(1).numCachedPartitions === 3) - assert(rddInfos(1).memSize === 3L) - assert(rddInfos(1).diskSize === 6L) - assert(rddInfos(1).externalBlockStoreSize === 0L) - } - - test("StorageUtils.getRddBlockLocations") { - val storageStatuses = stockStorageStatuses - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0.contains(RDDBlockId(0, 0))) - assert(blockLocations0.contains(RDDBlockId(0, 1))) - assert(blockLocations0.contains(RDDBlockId(0, 2))) - assert(blockLocations0.contains(RDDBlockId(0, 3))) - assert(blockLocations0.contains(RDDBlockId(0, 4))) - assert(blockLocations1.contains(RDDBlockId(1, 0))) - assert(blockLocations1.contains(RDDBlockId(1, 1))) - assert(blockLocations1.contains(RDDBlockId(1, 2))) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - - test("StorageUtils.getRddBlockLocations with multiple locations") { - val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - private val offheap = StorageLevel.OFF_HEAP // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD onheap // and offheap blocks @@ -373,21 +169,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) - assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) - - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } private def storageStatus4: StorageStatus = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d35c50e1d00fe..381f7b5be1ddf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,20 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20659] Remove StorageStatus, or make it private + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") ) // Exclude rules for 2.3.x diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index ec3d790255ad3..d49e0fd85229f 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -350,7 +350,7 @@ class SingletonReplSuite extends SparkFunSuite { """ |val timeout = 60000 // 60 seconds |val start = System.currentTimeMillis - |while(sc.getExecutorStorageStatus.size != 3 && + |while(sc.statusTracker.getExecutorInfos.size != 3 && | (System.currentTimeMillis - start) < timeout) { | Thread.sleep(10) |} @@ -361,11 +361,11 @@ class SingletonReplSuite extends SparkFunSuite { |case class Foo(i: Int) |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2) |ret.count() - |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + |val res = sc.getRDDStorageInfo.filter(_.id == ret.id).map(_.numCachedPartitions).sum """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res: Int = 20", output) + assertContains("res: Int = 10", output) } test("should clone and clean line object in ClosureCleaner") { From d6e1958a2472898e60bd013902c2f35111596e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 09:54:52 -0600 Subject: [PATCH 032/593] [SPARK-23189][CORE][WEB UI] Reflect stage level blacklisting on executor tab MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The purpose of this PR to reflect the stage level blacklisting on the executor tab for the currently active stages. After this change in the executor tab at the Status column one of the following label will be: - "Blacklisted" when the executor is blacklisted application level (old flag) - "Dead" when the executor is not Blacklisted and not Active - "Blacklisted in Stages: [...]" when the executor is Active but the there are active blacklisted stages for the executor. Within the [] coma separated active stageIDs are listed. - "Active" when the executor is Active and there is no active blacklisted stages for the executor ## How was this patch tested? Both with unit tests and manually. #### Manual test Spark was started as: ```bash bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" ``` And the job was: ```scala import org.apache.spark.SparkEnv val pairs = sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt == 0) throw new RuntimeException("Bad executor") else { Thread.sleep(10) (x % 10, x) } } val all = pairs.cogroup(pairs) all.collect() ``` UI screenshots about the running: - One executor is blacklisted in the two stages: ![One executor is blacklisted in two stages](https://issues.apache.org/jira/secure/attachment/12908314/multiple_stages_1.png) - One stage completes the other one is still running: ![One stage completes the other is still running](https://issues.apache.org/jira/secure/attachment/12908315/multiple_stages_2.png) - Both stages are completed: ![Both stages are completed](https://issues.apache.org/jira/secure/attachment/12908316/multiple_stages_3.png) ### Unit tests In AppStatusListenerSuite.scala both the node blacklisting for a stage and the executor blacklisting for stage are tested. Author: “attilapiros” Closes #20408 from attilapiros/SPARK-23189. --- .../apache/spark/ui/static/executorspage.js | 21 +++++--- .../spark/status/AppStatusListener.scala | 49 ++++++++++++++----- .../org/apache/spark/status/LiveEntity.scala | 7 ++- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../executor_list_json_expectation.json | 3 +- .../executor_memory_usage_expectation.json | 15 ++++-- ...xecutor_node_blacklisting_expectation.json | 15 ++++-- ...acklisting_unblacklisting_expectation.json | 15 ++++-- .../spark/status/AppStatusListenerSuite.scala | 21 ++++++++ 9 files changed, 113 insertions(+), 36 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index d430d8c5fb35a..6717af3ac4daf 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -25,12 +25,18 @@ function getThreadDumpEnabled() { return threadDumpEnabled; } -function formatStatus(status, type) { +function formatStatus(status, type, row) { + if (row.isBlacklisted) { + return "Blacklisted"; + } + if (status) { - return "Active" - } else { - return "Dead" + if (row.blacklistedInStages.length == 0) { + return "Active" + } + return "Active (Blacklisted in Stages: [" + row.blacklistedInStages.join(", ") + "])"; } + return "Dead" } jQuery.extend(jQuery.fn.dataTableExt.oSort, { @@ -415,9 +421,10 @@ $(document).ready(function () { } }, {data: 'hostPort'}, - {data: 'isActive', render: function (data, type, row) { - if (row.isBlacklisted) return "Blacklisted"; - else return formatStatus (data, type); + { + data: 'isActive', + render: function (data, type, row) { + return formatStatus (data, type, row); } }, {data: 'rddBlocks'}, diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index ab01cddfca5b0..79a17e26665fd 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -213,11 +213,13 @@ private[spark] class AppStatusListener( override def onExecutorBlacklistedForStage( event: SparkListenerExecutorBlacklistedForStage): Unit = { + val now = System.nanoTime() + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - val now = System.nanoTime() - val esummary = stage.executorSummary(event.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) + setStageBlackListStatus(stage, now, event.executorId) + } + liveExecutors.get(event.executorId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } @@ -226,16 +228,29 @@ private[spark] class AppStatusListener( // Implicitly blacklist every available executor for the stage associated with this node Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - liveExecutors.values.foreach { exec => - if (exec.hostname == event.hostId) { - val esummary = stage.executorSummary(exec.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) - } - } + val executorIds = liveExecutors.values.filter(_.host == event.hostId).map(_.executorId).toSeq + setStageBlackListStatus(stage, now, executorIds: _*) + } + liveExecutors.values.filter(_.hostname == event.hostId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } + private def addBlackListedStageTo(exec: LiveExecutor, stageId: Int, now: Long): Unit = { + exec.blacklistedInStages += stageId + liveUpdate(exec, now) + } + + private def setStageBlackListStatus(stage: LiveStage, now: Long, executorIds: String*): Unit = { + executorIds.foreach { executorId => + val executorStageSummary = stage.executorSummary(executorId) + executorStageSummary.isBlacklisted = true + maybeUpdate(executorStageSummary, now) + } + stage.blackListedExecutors ++= executorIds + maybeUpdate(stage, now) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } @@ -594,12 +609,24 @@ private[spark] class AppStatusListener( stage.executorSummaries.values.foreach(update(_, now)) update(stage, now, last = true) + + val executorIdsForStage = stage.blackListedExecutors + executorIdsForStage.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) + } + } } appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } + private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { + exec.blacklistedInStages -= stageId + liveUpdate(exec, now) + } + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index d5f9e19ffdcd0..79e3f13b826ce 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -20,6 +20,7 @@ package org.apache.spark.status import java.util.Date import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.{HashSet, TreeSet} import scala.collection.mutable.HashMap import com.google.common.collect.Interners @@ -254,6 +255,7 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE var totalShuffleRead = 0L var totalShuffleWrite = 0L var isBlacklisted = false + var blacklistedInStages: Set[Int] = TreeSet() var executorLogs = Map[String, String]() @@ -299,7 +301,8 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeTime), Option(removeReason), executorLogs, - memoryMetrics) + memoryMetrics, + blacklistedInStages) new ExecutorSummaryWrapper(info) } @@ -371,6 +374,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + var blackListedExecutors = new HashSet[String]() + // Used for cleanup of tasks after they reach the configured limit. Not written to the store. @volatile var cleaning = false var savedTasks = new AtomicInteger(0) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 550eac3952bbb..a333f1aaf6325 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -95,7 +95,8 @@ class ExecutorSummary private[spark]( val removeTime: Option[Date], val removeReason: Option[String], val executorLogs: Map[String, String], - val memoryMetrics: Option[MemoryMetrics]) + val memoryMetrics: Option[MemoryMetrics], + val blacklistedInStages: Set[Int]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 942e6d8f04363..7bb8fe8fd8f98 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -19,5 +19,6 @@ "isBlacklisted" : false, "maxMemory" : 278302556, "addTime" : "2015-02-03T16:43:00.906GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index ed33c90dd39ba..dd5b1dcb7372b 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ,{ "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 73519f1d9e2e4..3e55d3d9d7eb9 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 6931fead3d2ff..e87f3e78f2dc8 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -19,7 +19,8 @@ "isBlacklisted" : false, "maxMemory" : 384093388, "addTime" : "2016-11-15T23:20:38.836GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.111:64543", @@ -44,7 +45,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.111:64539", @@ -69,7 +71,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -94,7 +97,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.111:64540", @@ -119,5 +123,6 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b74d6ee2ec836..749502709b5c8 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -273,6 +273,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.head.stageId)) + } + // Blacklisting node for stage time += 1 listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( @@ -439,6 +443,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numCompleteTasks === pending.size) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set()) + } + // Submit stage 2. time += 1 stages.last.submissionTime = Some(time) @@ -453,6 +461,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) } + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "1.example.com", + executorFailures = 1, + stageId = stages.last.stageId, + stageAttemptId = stages.last.attemptId)) + + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) + } + // Start and fail all tasks of stage 2. time += 1 val s2Tasks = createTasks(4, execIds) From 091a000d27f324de8c5c527880854ecfcf5de9a4 Mon Sep 17 00:00:00 2001 From: huangtengfei Date: Tue, 13 Feb 2018 09:59:21 -0600 Subject: [PATCH 033/593] [SPARK-23053][CORE] taskBinarySerialization and task partitions calculate in DagScheduler.submitMissingTasks should keep the same RDD checkpoint status ## What changes were proposed in this pull request? When we run concurrent jobs using the same rdd which is marked to do checkpoint. If one job has finished running the job, and start the process of RDD.doCheckpoint, while another job is submitted, then submitStage and submitMissingTasks will be called. In [submitMissingTasks](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L961), will serialize taskBinaryBytes and calculate task partitions which are both affected by the status of checkpoint, if the former is calculated before doCheckpoint finished, while the latter is calculated after doCheckpoint finished, when run task, rdd.compute will be called, for some rdds with particular partition type such as [UnionRDD](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala) who will do partition type cast, will get a ClassCastException because the part params is actually a CheckpointRDDPartition. This error occurs because rdd.doCheckpoint occurs in the same thread that called sc.runJob, while the task serialization occurs in the DAGSchedulers event loop. ## How was this patch tested? the exist uts and also add a test case in DAGScheduerSuite to show the exception case. Author: huangtengfei Closes #20244 from ivoson/branch-taskpart-mistype. --- .../apache/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 199937b8c27af..8c46a84323392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1016,15 +1016,24 @@ class DAGScheduler( // might modify state of objects referenced in their closures. This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. var taskBinary: Broadcast[Array[Byte]] = null + var partitions: Array[Partition] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = stage match { - case stage: ShuffleMapStage => - JavaUtils.bufferToArray( - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) - case stage: ResultStage => - JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + var taskBinaryBytes: Array[Byte] = null + // taskBinaryBytes and partitions are both effected by the checkpoint status. We need + // this synchronization in case another concurrent job is checkpointing this RDD, so we get a + // consistent view of both variables. + RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions } taskBinary = sc.broadcast(taskBinaryBytes) @@ -1049,7 +1058,7 @@ class DAGScheduler( stage.pendingPartitions.clear() partitionsToCompute.map { id => val locs = taskIdToLocations(id) - val part = stage.rdd.partitions(id) + val part = partitions(id) stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), @@ -1059,7 +1068,7 @@ class DAGScheduler( case stage: ResultStage => partitionsToCompute.map { id => val p: Int = stage.partitions(id) - val part = stage.rdd.partitions(p) + val part = partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, From bd24731722a9142c90cf3d76008115f308203844 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 11:39:33 -0600 Subject: [PATCH 034/593] [SPARK-23382][WEB-UI] Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. please refer to https://github.com/apache/spark/pull/20216 fix after: ![1](https://user-images.githubusercontent.com/26266482/36068644-df029328-0f14-11e8-8350-cfdde9733ffc.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20570 from guoxiaolongzte/SPARK-23382. --- .../org/apache/spark/ui/static/webui.js | 2 + .../spark/streaming/ui/StreamingPage.scala | 37 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e575c4c78970d..83009df91d30a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -80,4 +80,6 @@ $(function() { collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); + collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); + collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); }); \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 7abafd6ba7908..3a176f64cdd60 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -490,15 +490,40 @@ private[ui] class StreamingPage(parent: StreamingTab) sortBy(_.batchTime.milliseconds).reverse val activeBatchesContent = { -

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ - new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Active Batches ({runningBatches.size + waitingBatches.size}) +

    +
    +
    + {new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } val completedBatchesContent = { -

    - Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) -

    ++ - new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Completed Batches (last {completedBatches.size} + out of {listener.numTotalCompletedBatches}) +

    +
    +
    + {new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } activeBatchesContent ++ completedBatchesContent From 263531466f4a7e223c94caa8705e6e8394a12054 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 13 Feb 2018 11:45:20 -0600 Subject: [PATCH 035/593] [SPARK-23392][TEST] Add some test cases for images feature ## What changes were proposed in this pull request? Add some test cases for images feature ## How was this patch tested? Add some test cases in ImageSchemaSuite Author: xubo245 <601450868@qq.com> Closes #20583 from xubo245/CARBONDATA23392_AddTestForImage. --- .../spark/ml/image/ImageSchemaSuite.scala | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index a8833c615865d..527b3f8955968 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -65,11 +65,71 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(count50 > 0 && count50 < countTotal) } + test("readImages test: recursive = false") { + val df = readImages(imagePath, null, false, 3, true, 1.0, 0) + assert(df.count() === 0) + } + + test("readImages test: read jpg image") { + val df = readImages(imagePath + "/kittens/DP153539.jpg", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read png image") { + val df = readImages(imagePath + "/multi-channel/BGRA.png", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read non image") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, true, 1.0, 0) + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + assert(df.count() === 0) + } + + test("readImages test: read non image and dropImageFailures is false") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, false, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: sampleRatio > 1") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, 1.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio < 0") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, -0.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio = 0") { + val df = readImages(imagePath, null, true, 3, true, 0.0, 0) + assert(df.count() === 0) + } + + test("readImages test: with sparkSession") { + val df = readImages(imagePath, sparkSession = spark, true, 3, true, 1.0, 0) + assert(df.count() === 8) + } + test("readImages partition test") { val df = readImages(imagePath, null, true, 3, true, 1.0, 0) assert(df.rdd.getNumPartitions === 3) } + test("readImages partition test: < 0") { + val df = readImages(imagePath, null, true, -3, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + + test("readImages partition test: = 0") { + val df = readImages(imagePath, null, true, 0, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + // Images with the different number of channels test("readImages pixel values test") { @@ -93,7 +153,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // - default representation for 3-channel RGB images is BGR row-wise: // (B00, G00, R00, B10, G10, R10, ...) // - default representation for 4-channel RGB images is BGRA row-wise: - // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) private val firstBytes20 = Map( "grayscale.jpg" -> (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, From 05d051293fe46938e9cb012342fea6e8a3715cd4 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Tue, 13 Feb 2018 09:49:52 -0800 Subject: [PATCH 036/593] [SPARK-23316][SQL] AnalysisException after max iteration reached for IN query ## What changes were proposed in this pull request? Added flag ignoreNullability to DataType.equalsStructurally. The previous semantic is for ignoreNullability=false. When ignoreNullability=true equalsStructurally ignores nullability of contained types (map key types, value types, array element types, structure field types). In.checkInputTypes calls equalsStructurally to check if the children types match. They should match regardless of nullability (which is just a hint), so it is now called with ignoreNullability=true. ## How was this patch tested? New test in SubquerySuite Author: Bogdan Raducanu Closes #20548 from bogdanrdc/SPARK-23316. --- .../sql/catalyst/expressions/predicates.scala | 3 ++- .../org/apache/spark/sql/types/DataType.scala | 18 ++++++++++++------ .../org/apache/spark/sql/SubquerySuite.scala | 5 +++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b469f5cb7586a..a6d41ea7d00d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -157,7 +157,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType)) + val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, + ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index d6e0df12218ad..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -295,25 +295,31 @@ object DataType { } /** - * Returns true if the two data types share the same "shape", i.e. the types (including - * nullability) are the same, but the field names don't need to be the same. + * Returns true if the two data types share the same "shape", i.e. the types + * are the same, but the field names don't need to be the same. + * + * @param ignoreNullability whether to ignore nullability when comparing the types */ - def equalsStructurally(from: DataType, to: DataType): Boolean = { + def equalsStructurally( + from: DataType, + to: DataType, + ignoreNullability: Boolean = false): Boolean = { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurally(left.elementType, right.elementType) && - left.containsNull == right.containsNull + (ignoreNullability || left.containsNull == right.containsNull) case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType) && equalsStructurally(left.valueType, right.valueType) && - left.valueContainsNull == right.valueContainsNull + (ignoreNullability || left.valueContainsNull == right.valueContainsNull) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields) .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + equalsStructurally(l.dataType, r.dataType) && + (ignoreNullability || l.nullable == r.nullable) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 8673dc14f7597..31e8b0e8dede0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -950,4 +950,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(join.duplicateResolved) assert(optimizedPlan.resolved) } + + test("SPARK-23316: AnalysisException after max iteration reached for IN query") { + // before the fix this would throw AnalysisException + spark.range(10).where("(id,id) in (select id, null from range(3))").count + } } From 4e0fb010ccdf13fe411f2a4796bbadc385b01520 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 13 Feb 2018 11:51:19 -0600 Subject: [PATCH 037/593] [SPARK-23217][ML] Add cosine distance measure to ClusteringEvaluator ## What changes were proposed in this pull request? The PR provided an implementation of ClusteringEvaluator using the cosine distance measure. This allows to evaluate clustering results created using the cosine distance, introduced in SPARK-22119. In the corresponding JIRA, there is a design document for the algorithm implemented here. ## How was this patch tested? Added UT which compares the result to the one provided by python sklearn. Author: Marco Gaido Closes #20396 from mgaido91/SPARK-23217. --- .../ml/evaluation/ClusteringEvaluator.scala | 334 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 32 +- 2 files changed, 300 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index d6ec5223237bb..8d4ae562b3d2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,11 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, + SchemaUtils} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -32,15 +33,11 @@ import org.apache.spark.sql.types.DoubleType * :: Experimental :: * * Evaluator for clustering results. - * The metric computes the Silhouette measure - * using the squared Euclidean distance. - * - * The Silhouette is a measure for the validation - * of the consistency within clusters. It ranges - * between 1 and -1, where a value close to 1 - * means that the points in a cluster are close - * to the other points in the same cluster and - * far from the points of the other clusters. + * The metric computes the Silhouette measure using the specified distance measure. + * + * The Silhouette is a measure for the validation of the consistency within clusters. It ranges + * between 1 and -1, where a value close to 1 means that the points in a cluster are close to the + * other points in the same cluster and far from the points of the other clusters. */ @Experimental @Since("2.3.0") @@ -84,18 +81,40 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") def setMetricName(value: String): this.type = set(metricName, value) - setDefault(metricName -> "silhouette") + /** + * param for distance measure to be used in evaluation + * (supports `"squaredEuclidean"` (default), `"cosine"`) + * @group param + */ + @Since("2.4.0") + val distanceMeasure: Param[String] = { + val availableValues = Array("squaredEuclidean", "cosine") + val allowedParams = ParamValidators.inArray(availableValues) + new Param(this, "distanceMeasure", "distance measure in evaluation. Supported options: " + + availableValues.mkString("'", "', '", "'"), allowedParams) + } + + /** @group getParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + + /** @group setParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + + setDefault(metricName -> "silhouette", distanceMeasure -> "squaredEuclidean") @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) - $(metricName) match { - case "silhouette" => + ($(metricName), $(distanceMeasure)) match { + case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol) - ) + dataset, $(predictionCol), $(featuresCol)) + case ("silhouette", "cosine") => + CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) } } } @@ -111,6 +130,48 @@ object ClusteringEvaluator } +private[evaluation] abstract class Silhouette { + + /** + * It computes the Silhouette coefficient for a point. + */ + def pointSilhouetteCoefficient( + clusterIds: Set[Double], + pointClusterId: Double, + pointClusterNumOfPoints: Long, + averageDistanceToCluster: (Double) => Double): Double = { + // Here we compute the average dissimilarity of the current point to any cluster of which the + // point is not a member. + // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current + // point - is said to be the "neighboring cluster". + val otherClusterIds = clusterIds.filter(_ != pointClusterId) + val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min + + // adjustment for excluding the node itself from the computation of the average dissimilarity + val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) { + 0.0 + } else { + averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints / + (pointClusterNumOfPoints - 1) + } + + if (currentClusterDissimilarity < neighboringClusterDissimilarity) { + 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) + } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { + (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 + } else { + 0.0 + } + } + + /** + * Compute the mean Silhouette values of all samples. + */ + def overallScore(df: DataFrame, scoreColumn: Column): Double = { + df.select(avg(scoreColumn)).collect()(0).getDouble(0) + } +} + /** * SquaredEuclideanSilhouette computes the average of the * Silhouette over all the data of the dataset, which is @@ -259,7 +320,7 @@ object ClusteringEvaluator * `N` is the number of points in the dataset and `W` is the number * of worker nodes. */ -private[evaluation] object SquaredEuclideanSilhouette { +private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { private[this] var kryoRegistrationPerformed: Boolean = false @@ -336,18 +397,19 @@ private[evaluation] object SquaredEuclideanSilhouette { * It computes the Silhouette coefficient for a point. * * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point. + * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. * @param clusterId The id of the cluster the current point belongs to. * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. * @return The Silhouette for the point. */ def computeSilhouetteCoefficient( broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], - features: Vector, + point: Vector, clusterId: Double, squaredNorm: Double): Double = { - def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = { + def compute(targetClusterId: Double): Double = { + val clusterStats = broadcastedClustersMap.value(targetClusterId) val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) squaredNorm + @@ -355,41 +417,14 @@ private[evaluation] object SquaredEuclideanSilhouette { 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints } - // Here we compute the average dissimilarity of the - // current point to any cluster of which the point - // is not a member. - // The cluster with the lowest average dissimilarity - // - i.e. the nearest cluster to the current point - - // is said to be the "neighboring cluster". - var neighboringClusterDissimilarity = Double.MaxValue - broadcastedClustersMap.value.keySet.foreach { - c => - if (c != clusterId) { - val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c)) - if(dissimilarity < neighboringClusterDissimilarity) { - neighboringClusterDissimilarity = dissimilarity - } - } - } - val currentCluster = broadcastedClustersMap.value(clusterId) - // adjustment for excluding the node itself from - // the computation of the average dissimilarity - val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) { - 0 - } else { - compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints / - (currentCluster.numOfPoints - 1) - } - - (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match { - case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) - case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 - case 0 => 0.0 - } + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId).numOfPoints, + compute) } /** - * Compute the mean Silhouette values of all samples. + * Compute the Silhouette score of the dataset using squared Euclidean distance measure. * * @param dataset The input dataset (previously clustered) on which compute the Silhouette. * @param predictionCol The name of the column which contains the predicted cluster id @@ -412,7 +447,7 @@ private[evaluation] object SquaredEuclideanSilhouette { val clustersStatsMap = SquaredEuclideanSilhouette .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol) - // Silhouette is reasonable only when the number of clusters is grater then 1 + // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) @@ -421,13 +456,190 @@ private[evaluation] object SquaredEuclideanSilhouette { computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) } - val silhouetteScore = dfWithSquaredNorm - .select(avg( - computeSilhouetteCoefficientUDF( - col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm")) - )) - .collect()(0) - .getDouble(0) + val silhouetteScore = overallScore(dfWithSquaredNorm, + computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), + col("squaredNorm"))) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} + + +/** + * The algorithm which is implemented in this object, instead, is an efficient and parallel + * implementation of the Silhouette using the cosine distance measure. The cosine distance + * measure is defined as `1 - s` where `s` is the cosine similarity between two points. + * + * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` + * is: + * + *
    + * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) + * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} + * \frac{c_{ij}}{\|C_{i}\|} + * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N + * \frac{c_{ij}}{\|C_{i}\|} \Big) + * $$ + *
    + * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension + * of the `i`-th point in cluster `$\Gamma$`. + * + * Then, we can define the vector: + * + *
    + * $$ + * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed for each point and the vector + * + *
    + * $$ + * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. + * + * With these definitions, the numerator becomes: + * + *
    + * $$ + * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} + * $$ + *
    + * + * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: + * + *
    + * $$ + * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} + * $$ + *
    + * + * In the implementation, the precomputed values for the clusters are distributed among the worker + * nodes via broadcasted variables, because we can assume that the clusters are limited in number. + * + * The main strengths of this algorithm are the low computational complexity and the intrinsic + * parallelism. The precomputed information for each point and for each cluster can be computed + * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the + * dataset and `W` is the number of worker nodes. After that, every point can be analyzed + * independently from the others. + * + * For every point we need to compute the average distance to all the clusters. Since the formula + * above requires `O(D)` operations, this phase has a computational complexity which is + * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number + * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker + * nodes. + */ +private[evaluation] object CosineSilhouette extends Silhouette { + + private[this] val normalizedFeaturesColName = "normalizedFeatures" + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a + * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). + */ + def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) + .rdd + .map { row => (row.getDouble(0), row.getAs[Vector](1)) } + .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))( + seqOp = { + case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) => + BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum) + (normalizedFeaturesSum, numOfPoints + 1) + }, + combOp = { + case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) => + BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) + (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2) + } + ) + + clustersStatsRDD + .collectAsMap() + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the + * normalized features of the current point. + * @param clusterId The id of the cluster the current point belongs to. + */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]], + normalizedFeatures: Vector, + clusterId: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) + 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId)._2, + compute) + } + + /** + * Compute the Silhouette score of the dataset using the cosine distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @return The average of the Silhouette values of the clustered data. + */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String): Double = { + val normalizeFeatureUDF = udf { + features: Vector => { + val norm = Vectors.norm(features, 2.0) + features match { + case d: DenseVector => Vectors.dense(d.values.map(_ / norm)) + case s: SparseVector => Vectors.sparse(s.size, s.indices, s.values.map(_ / norm)) + } + } + } + val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, + normalizeFeatureUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double) + } + + val silhouetteScore = overallScore(dfWithNormalizedFeatures, + computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), + col(predictionCol).cast(DoubleType))) bClustersStatsMap.destroy() diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 677ce49a903ab..3bf34770f5687 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -66,16 +66,38 @@ class ClusteringEvaluatorSuite assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } - test("number of clusters must be greater than one") { - val singleClusterDataset = irisDataset.where($"label" === 0.0) + /* + Use the following python code to load the data and evaluate it using scikit-learn package. + + from sklearn import datasets + from sklearn.metrics import silhouette_score + iris = datasets.load_iris() + round(silhouette_score(iris.data, iris.target, metric='cosine'), 10) + + 0.7222369298 + */ + test("cosine Silhouette") { val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") + .setDistanceMeasure("cosine") + + assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + } + + test("number of clusters must be greater than one") { + val singleClusterDataset = irisDataset.where($"label" === 0.0) + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) - val e = intercept[AssertionError]{ - evaluator.evaluate(singleClusterDataset) + val e = intercept[AssertionError] { + evaluator.evaluate(singleClusterDataset) + } + assert(e.getMessage.contains("Number of clusters must be greater than one")) } - assert(e.getMessage.contains("Number of clusters must be greater than one")) } } From d58fe28836639e68e262812d911f167cb071007b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 13 Feb 2018 11:18:45 -0800 Subject: [PATCH 038/593] [SPARK-23154][ML][DOC] Document backwards compatibility guarantees for ML persistence ## What changes were proposed in this pull request? Added documentation about what MLlib guarantees in terms of loading ML models and Pipelines from old Spark versions. Discussed & confirmed on linked JIRA. Author: Joseph K. Bradley Closes #20592 from jkbradley/SPARK-23154-backwards-compat-doc. --- docs/ml-pipeline.md | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index aa92c0a37c0f4..e22e9003c30f6 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -## Saving and Loading Pipelines +## ML persistence: Saving and Loading Pipelines -Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. +As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage. + +ML persistence works across Scala, Java and Python. However, R currently uses a modified format, +so models saved in R can only be loaded back in R; this should be fixed in the future and is +tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572). + +### Backwards compatibility for ML persistence + +In general, MLlib maintains backwards compatibility for ML persistence. I.e., if you save an ML +model or Pipeline in one version of Spark, then you should be able to load it back and use it in a +future version of Spark. However, there are rare exceptions, described below. + +Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark +version X loadable by Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Yes; these are backwards compatible. +* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible. + +Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Identical behavior, except for bug fixes. + +For both model persistence and model behavior, any breaking changes across a minor version or patch +version are reported in the Spark version release notes. If a breakage is not reported in release +notes, then it should be treated as a bug to be fixed. # Code examples From 2ee76c22b6e48e643694c9475e5f0d37124215e7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 11:56:49 -0800 Subject: [PATCH 039/593] [SPARK-23400][SQL] Add a constructors for ScalaUDF ## What changes were proposed in this pull request? In this upcoming 2.3 release, we changed the interface of `ScalaUDF`. Unfortunately, some Spark packages (e.g., spark-deep-learning) are using our internal class `ScalaUDF`. In the release 2.3, we added new parameters into this class. The users hit the binary compatibility issues and got the exception: ``` > java.lang.NoSuchMethodError: org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(Ljava/lang/Object;Lorg/apache/spark/sql/types/DataType;Lscala/collection/Seq;Lscala/collection/Seq;Lscala/Option;)V ``` This PR is to improve the backward compatibility. However, we definitely should not encourage the external packages to use our internal classes. This might make us hard to maintain/develop the codes in Spark. ## How was this patch tested? N/A Author: gatorsmile Closes #20591 from gatorsmile/scalaUDF. --- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 388ef42883ad3..989c02305620a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,6 +49,17 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + // The constructor for SPARK 2.1 and 2.2 + def this( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType], + udfName: Option[String]) = { + this( + function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + } + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = From a5a4b83501526e02d0e3cd0056e4a5c0e1c8284f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 16:46:43 -0600 Subject: [PATCH 040/593] [SPARK-23235][CORE] Add executor Threaddump to api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending api with the executor thread dump data. For this new REST URL is introduced: - GET http://localhost:4040/api/v1/applications/{applicationId}/executors/{executorId}/threads
    Example response: ``` javascript [ { "threadId" : 52, "threadName" : "context-cleaner-periodic-gc", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1385411893})", "holdingLocks" : [ ] }, { "threadId" : 48, "threadName" : "dag-scheduler-event-loop", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingDeque.takeFirst(LinkedBlockingDeque.java:492)\njava.util.concurrent.LinkedBlockingDeque.take(LinkedBlockingDeque.java:680)\norg.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:46)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1138053349})", "holdingLocks" : [ ] }, { "threadId" : 17, "threadName" : "dispatcher-event-loop-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker832743930})" ] }, { "threadId" : 18, "threadName" : "dispatcher-event-loop-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker834153999})" ] }, { "threadId" : 19, "threadName" : "dispatcher-event-loop-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker664836465})" ] }, { "threadId" : 20, "threadName" : "dispatcher-event-loop-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1645557354})" ] }, { "threadId" : 21, "threadName" : "dispatcher-event-loop-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1188871851})" ] }, { "threadId" : 22, "threadName" : "dispatcher-event-loop-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker920926249})" ] }, { "threadId" : 23, "threadName" : "dispatcher-event-loop-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker355222677})" ] }, { "threadId" : 24, "threadName" : "dispatcher-event-loop-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1589745212})" ] }, { "threadId" : 49, "threadName" : "driver-heartbeater", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1602885835})", "holdingLocks" : [ ] }, { "threadId" : 53, "threadName" : "element-tracking-store-worker", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1439439099})", "holdingLocks" : [ ] }, { "threadId" : 3, "threadName" : "Finalizer", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:164)\njava.lang.ref.Finalizer$FinalizerThread.run(Finalizer.java:209)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1213098236})", "holdingLocks" : [ ] }, { "threadId" : 15, "threadName" : "ForkJoinPool-1-worker-13", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\nscala.concurrent.forkjoin.ForkJoinPool.scan(ForkJoinPool.java:2075)\nscala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)\nscala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)", "blockedByLock" : "Lock(scala.concurrent.forkjoin.ForkJoinPool380286413})", "holdingLocks" : [ ] }, { "threadId" : 45, "threadName" : "heartbeat-receiver-event-loop-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject715135812})", "holdingLocks" : [ ] }, { "threadId" : 1, "threadName" : "main", "threadState" : "RUNNABLE", "stackTrace" : "java.io.FileInputStream.read0(Native Method)\njava.io.FileInputStream.read(FileInputStream.java:207)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:169) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:137)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:246)\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:261) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:198) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.console.ConsoleReader.readCharacter(ConsoleReader.java:2145)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2349)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2269)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readOneLine(JLineReader.scala:57)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$.restartSysCalls(InteractiveReader.scala:44)\nscala.tools.nsc.interpreter.InteractiveReader$class.readLine(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readLine(JLineReader.scala:28)\nscala.tools.nsc.interpreter.ILoop.readOneLine(ILoop.scala:404)\nscala.tools.nsc.interpreter.ILoop.loop(ILoop.scala:413)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply$mcZ$sp(ILoop.scala:923)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.reflect.internal.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:97)\nscala.tools.nsc.interpreter.ILoop.process(ILoop.scala:909)\norg.apache.spark.repl.Main$.doMain(Main.scala:76)\norg.apache.spark.repl.Main$.main(Main.scala:56)\norg.apache.spark.repl.Main.main(Main.scala)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52)\norg.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:879)\norg.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197)\norg.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227)\norg.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136)\norg.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})" ] }, { "threadId" : 26, "threadName" : "map-output-dispatcher-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1791280119})" ] }, { "threadId" : 27, "threadName" : "map-output-dispatcher-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1947378744})" ] }, { "threadId" : 28, "threadName" : "map-output-dispatcher-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker507507251})" ] }, { "threadId" : 29, "threadName" : "map-output-dispatcher-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1016408627})" ] }, { "threadId" : 30, "threadName" : "map-output-dispatcher-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1879219501})" ] }, { "threadId" : 31, "threadName" : "map-output-dispatcher-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker290509937})" ] }, { "threadId" : 32, "threadName" : "map-output-dispatcher-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1889468930})" ] }, { "threadId" : 33, "threadName" : "map-output-dispatcher-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1699637904})" ] }, { "threadId" : 47, "threadName" : "netty-rpc-env-timeout", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject977194847})", "holdingLocks" : [ ] }, { "threadId" : 14, "threadName" : "NonBlockingInputStreamThread", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.run(NonBlockingInputStream.java:278)\njava.lang.Thread.run(Thread.java:748)", "blockedByThreadId" : 1, "blockedByLock" : "Lock(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})", "holdingLocks" : [ ] }, { "threadId" : 2, "threadName" : "Reference Handler", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.lang.ref.Reference.tryHandlePending(Reference.java:191)\njava.lang.ref.Reference$ReferenceHandler.run(Reference.java:153)", "blockedByLock" : "Lock(java.lang.ref.Reference$Lock1359433302})", "holdingLocks" : [ ] }, { "threadId" : 35, "threadName" : "refresh progress", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.util.TimerThread.mainLoop(Timer.java:552)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue44276328})", "holdingLocks" : [ ] }, { "threadId" : 34, "threadName" : "RemoteBlock-temp-file-clean-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager.org$apache$spark$storage$BlockManager$RemoteBlockTempFileManager$$keepCleaning(BlockManager.scala:1630)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager$$anon$1.run(BlockManager.scala:1608)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock391748181})", "holdingLocks" : [ ] }, { "threadId" : 25, "threadName" : "rpc-server-3-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet1066929256})", "Monitor(java.util.Collections$UnmodifiableSet561426729})", "Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})" ] }, { "threadId" : 50, "threadName" : "shuffle-server-5-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet385972319})", "Monitor(java.util.Collections$UnmodifiableSet477937109})", "Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})" ] }, { "threadId" : 4, "threadName" : "Signal Dispatcher", "threadState" : "RUNNABLE", "stackTrace" : "", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 51, "threadName" : "Spark Context Cleaner", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.ContextCleaner$$anonfun$org$apache$spark$ContextCleaner$$keepCleaning$1.apply$mcV$sp(ContextCleaner.scala:181)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.ContextCleaner.org$apache$spark$ContextCleaner$$keepCleaning(ContextCleaner.scala:178)\norg.apache.spark.ContextCleaner$$anon$1.run(ContextCleaner.scala:73)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1739420764})", "holdingLocks" : [ ] }, { "threadId" : 16, "threadName" : "spark-listener-group-appStatus", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1287190987})", "holdingLocks" : [ ] }, { "threadId" : 44, "threadName" : "spark-listener-group-executorManagement", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject943262890})", "holdingLocks" : [ ] }, { "threadId" : 54, "threadName" : "spark-listener-group-shared", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject334604425})", "holdingLocks" : [ ] }, { "threadId" : 37, "threadName" : "SparkUI-37", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 38, "threadName" : "SparkUI-38", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl841741934})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3873523986})", "Monitor(java.util.Collections$UnmodifiableSet1769333189})", "Monitor(sun.nio.ch.KQueueSelectorImpl841741934})" ] }, { "threadId" : 40, "threadName" : "SparkUI-40-acceptor-034929380-Spark3a557b62{HTTP/1.1,[http/1.1]}{0.0.0.0:4040}", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.ServerSocketChannelImpl.accept0(Native Method)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:422)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:250) => holding Monitor(java.lang.Object1134240909})\norg.spark_project.jetty.server.ServerConnector.accept(ServerConnector.java:371)\norg.spark_project.jetty.server.AbstractConnector$Acceptor.run(AbstractConnector.java:601)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(java.lang.Object1134240909})" ] }, { "threadId" : 43, "threadName" : "SparkUI-43", "threadState" : "RUNNABLE", "stackTrace" : "sun.management.ThreadImpl.dumpThreads0(Native Method)\nsun.management.ThreadImpl.dumpAllThreads(ThreadImpl.java:454)\norg.apache.spark.util.Utils$.getThreadDump(Utils.scala:2170)\norg.apache.spark.SparkContext.getExecutorThreadDump(SparkContext.scala:596)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:66)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:65)\nscala.Option.flatMap(Option.scala:171)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:65)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:58)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:139)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:134)\norg.apache.spark.ui.SparkUI.withSparkUI(SparkUI.scala:106)\norg.apache.spark.status.api.v1.BaseAppResource$class.withUI(ApiRootResource.scala:134)\norg.apache.spark.status.api.v1.AbstractApplicationResource.withUI(OneApplicationResource.scala:32)\norg.apache.spark.status.api.v1.AbstractApplicationResource.threadDump(OneApplicationResource.scala:58)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.glassfish.jersey.server.model.internal.ResourceMethodInvocationHandlerFactory$1.invoke(ResourceMethodInvocationHandlerFactory.java:81)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher$1.run(AbstractJavaResourceMethodDispatcher.java:144)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.invoke(AbstractJavaResourceMethodDispatcher.java:161)\norg.glassfish.jersey.server.model.internal.JavaResourceMethodDispatcherProvider$TypeOutInvoker.doDispatch(JavaResourceMethodDispatcherProvider.java:205)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.dispatch(AbstractJavaResourceMethodDispatcher.java:99)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.invoke(ResourceMethodInvoker.java:389)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:347)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:102)\norg.glassfish.jersey.server.ServerRuntime$2.run(ServerRuntime.java:326)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:271)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:267)\norg.glassfish.jersey.internal.Errors.process(Errors.java:315)\norg.glassfish.jersey.internal.Errors.process(Errors.java:297)\norg.glassfish.jersey.internal.Errors.process(Errors.java:267)\norg.glassfish.jersey.process.internal.RequestScope.runInScope(RequestScope.java:317)\norg.glassfish.jersey.server.ServerRuntime.process(ServerRuntime.java:305)\norg.glassfish.jersey.server.ApplicationHandler.handle(ApplicationHandler.java:1154)\norg.glassfish.jersey.servlet.WebComponent.serviceImpl(WebComponent.java:473)\norg.glassfish.jersey.servlet.WebComponent.service(WebComponent.java:427)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:388)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:341)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:228)\norg.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848)\norg.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584)\norg.spark_project.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180)\norg.spark_project.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512)\norg.spark_project.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112)\norg.spark_project.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)\norg.spark_project.jetty.server.handler.gzip.GzipHandler.handle(GzipHandler.java:493)\norg.spark_project.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213)\norg.spark_project.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134)\norg.spark_project.jetty.server.Server.handle(Server.java:534)\norg.spark_project.jetty.server.HttpChannel.handle(HttpChannel.java:320)\norg.spark_project.jetty.server.HttpConnection.onFillable(HttpConnection.java:251)\norg.spark_project.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283)\norg.spark_project.jetty.io.FillInterest.fillable(FillInterest.java:108)\norg.spark_project.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 67, "threadName" : "SparkUI-67", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3881415814})", "Monitor(java.util.Collections$UnmodifiableSet62050480})", "Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})" ] }, { "threadId" : 68, "threadName" : "SparkUI-68", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl223607814})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3543145185})", "Monitor(java.util.Collections$UnmodifiableSet897441546})", "Monitor(sun.nio.ch.KQueueSelectorImpl223607814})" ] }, { "threadId" : 71, "threadName" : "SparkUI-71", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 77, "threadName" : "SparkUI-77", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 78, "threadName" : "SparkUI-78", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl403077801})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3261312406})", "Monitor(java.util.Collections$UnmodifiableSet852901260})", "Monitor(sun.nio.ch.KQueueSelectorImpl403077801})" ] }, { "threadId" : 72, "threadName" : "SparkUI-JettyScheduler", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1587346642})", "holdingLocks" : [ ] }, { "threadId" : 63, "threadName" : "task-result-getter-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 64, "threadName" : "task-result-getter-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 65, "threadName" : "task-result-getter-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 66, "threadName" : "task-result-getter-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 46, "threadName" : "Timer-0", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.util.TimerThread.mainLoop(Timer.java:526)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue635634547})", "holdingLocks" : [ ] } ] ```
    ## How was this patch tested? It was tested manually. Old executor page with thread dumps: screen shot 2018-02-01 at 14 31 19 New api: screen shot 2018-02-01 at 14 31 56 Testing error cases. Initial state: ![screen shot 2018-02-06 at 13 05 05](https://user-images.githubusercontent.com/2017933/35858990-ad2982be-0b3e-11e8-879b-656112065c7f.png) Dead executor: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/1/threads Executor is not active. 400 ``` Never existed (but well formatted: number) executor ID: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/42/threads Executor does not exist. 404 ``` Not available stacktrace (dead executor but UI has not registered as dead yet): ```bash $ kill -9 ; curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/2/threads No thread dump is available. 404 ``` Invalid executor ID format: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/something6/threads Invalid executorId: neither 'driver' nor number. 400 ``` Author: “attilapiros” Closes #20474 from attilapiros/SPARK-23235. --- .../scala/org/apache/spark/SparkContext.scala | 1 + .../spark/status/api/v1/ApiRootResource.scala | 8 +++++ .../api/v1/OneApplicationResource.scala | 29 +++++++++++++++-- .../org/apache/spark/status/api/v1/api.scala | 9 ++++++ .../ui/exec/ExecutorThreadDumpPage.scala | 13 +------- .../apache/spark/util/ThreadStackTrace.scala | 31 ------------------- .../scala/org/apache/spark/util/Utils.scala | 18 ++++++++++- docs/monitoring.md | 7 +++++ 8 files changed, 69 insertions(+), 47 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c4f74c4f1f9c2..dc531e3337014 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index ed9bdc6e1e3c2..7127397f6205c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -157,6 +157,14 @@ private[v1] class NotFoundException(msg: String) extends WebApplicationException .build() ) +private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( + new ServiceUnavailableException(msg), + Response + .status(Response.Status.SERVICE_UNAVAILABLE) + .entity(ErrorWrapper(msg)) + .build() +) + private[v1] class BadParameterException(msg: String) extends WebApplicationException( new IllegalArgumentException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index bd4df07e7afc6..974697890dd03 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -19,13 +19,13 @@ package org.apache.spark.status.api.v1 import java.io.OutputStream import java.util.{List => JList} import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs._ import javax.ws.rs.core.{MediaType, Response, StreamingOutput} import scala.util.control.NonFatal -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI +import org.apache.spark.{JobExecutionStatus, SparkContext} +import org.apache.spark.ui.UIUtils @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AbstractApplicationResource extends BaseAppResource { @@ -51,6 +51,29 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { @Path("executors") def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + @GET + @Path("executors/{executorId}/threads") + def threadDump(@PathParam("executorId") execId: String): Array[ThreadStackTrace] = withUI { ui => + if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) { + throw new BadParameterException( + s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.") + } + + val safeSparkContext = ui.sc.getOrElse { + throw new ServiceUnavailable("Thread dumps not available through the history server.") + } + + ui.store.asOption(ui.store.executorSummary(execId)) match { + case Some(executorSummary) if executorSummary.isActive => + val safeThreadDump = safeSparkContext.getExecutorThreadDump(execId).getOrElse { + throw new NotFoundException("No thread dump is available.") + } + safeThreadDump + case Some(_) => throw new BadParameterException("Executor is not active.") + case _ => throw new NotFoundException("Executor does not exist.") + } + } + @GET @Path("allexecutors") def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index a333f1aaf6325..369e98b683b1a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -316,3 +316,12 @@ class RuntimeInfo private[spark]( val javaVersion: String, val javaHome: String, val scalaVersion: String) + +case class ThreadStackTrace( + val threadId: Long, + val threadName: String, + val threadState: Thread.State, + val stackTrace: String, + val blockedByThreadId: Option[Long], + val blockedByLock: String, + val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f4686ea3cf91f..7a9aaf29a8b05 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.ui.exec -import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -41,17 +40,7 @@ private[ui] class ExecutorThreadDumpPage( val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => - val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 - val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 - if (v1 == v2) { - threadTrace1.threadName.toLowerCase(Locale.ROOT) < - threadTrace2.threadName.toLowerCase(Locale.ROOT) - } else { - v1 > v2 - } - }.map { thread => + val dumpRows = threadDump.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { case Some(_) => diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala deleted file mode 100644 index b1217980faf1f..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Used for shipping per-thread stacktraces from the executors to driver. - */ -private[spark] case class ThreadStackTrace( - threadId: Long, - threadName: String, - threadState: Thread.State, - stackTrace: String, - blockedByThreadId: Option[Long], - blockedByLock: String, - holdingLocks: Seq[String]) - diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5853302973140..d493663f0b168 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -63,6 +63,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.status.api.v1.ThreadStackTrace /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2168,7 +2169,22 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + threadInfos.sortWith { case (threadTrace1, threadTrace2) => + val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT) + val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT) + val nameCmpRes = name1.compareTo(name2) + if (nameCmpRes == 0) { + threadTrace1.getThreadId < threadTrace2.getThreadId + } else { + nameCmpRes < 0 + } + } else { + v1 > v2 + } + }.map(threadInfoToThreadStackTrace) } def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { diff --git a/docs/monitoring.md b/docs/monitoring.md index 6f6cfc1288d73..d5f7ffcc260a1 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -347,6 +347,13 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/executors A list of all active executors for the given application. + + /applications/[app-id]/executors/[executor-id]/threads + + Stack traces of all the threads running within the given active executor. + Not available via the history server. + + /applications/[app-id]/allexecutors A list of all(active and dead) executors for the given application. From d6f5e172b480c62165be168deae0deff8062f476 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 16:21:17 -0800 Subject: [PATCH 041/593] Revert "[SPARK-23303][SQL] improve the explain result for data source v2 relations" This reverts commit f17b936f0ddb7d46d1349bd42f9a64c84c06e48d. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +++- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 +++++++++++++ .../v2/DataSourceV2QueryPlan.scala | 96 ------------------- .../datasources/v2/DataSourceV2Relation.scala | 26 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 ++- 15 files changed, 127 insertions(+), 157 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 72ee0c551ec3d..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,9 +17,20 @@ package org.apache.spark.sql.kafka010 -import org.apache.spark.sql.Dataset +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -60,8 +71,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index d34458ac81014..5a1a14f7a307a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,8 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index cb09cce75ff6f..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,8 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 984b6510f2dbe..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,9 +189,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.newInstance() val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -219,7 +221,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..81219e9771bd8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The held data source reader. + */ + def reader: DataSourceReader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(output, reader.getClass, filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala deleted file mode 100644 index 1e0d088f3a57c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.util.Utils - -/** - * A base class for data source v2 related query plan(both logical and physical). It defines the - * equals/hashCode methods, and provides a string representation of the query plan, according to - * some common information. - */ -trait DataSourceV2QueryPlan { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. - */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } - - /** - * The metadata of this data source query plan that can be used for equality check. - */ - private def metadata: Seq[Any] = Seq(output, source.getClass, filters) - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } - - def metadataString: String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") - - val outputStr = Utils.truncatedString(output, "[", ", ", "]") - - val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) - }, " (", ", ", ")") - } else { - "" - } - - s"${source.getClass.getSimpleName}$outputStr$entriesStr" - } - - private def redact(text: String): String = { - Utils.redact(SQLConf.get.stringRedationPattern, text) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cd97e0cab6b5c..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,23 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - source: DataSourceV2, - reader: DataSourceReader, - override val isStreaming: Boolean) - extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] - override def simpleString: String = { - val streamingHeader = if (isStreaming) "Streaming " else "" - s"${streamingHeader}Relation $metadataString" - } - override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -49,8 +41,18 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { - def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) + def apply(reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c99d535efcf81..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,14 +36,11 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], - @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def simpleString: String = s"Scan $metadataString" - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index fb61e6f32b1f4..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 4cfdd50e8f46b..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r: DataSourceV2Relation) => + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = r.reader match { + val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84564b6639ac9..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,8 +52,6 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] - private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -92,7 +90,6 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -408,15 +405,12 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange(toJava(current), Optional.of(availableV2)) + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> new DataSourceV2Relation( - reader.readSchema().toAttributes, - // Provide a fake value here just in case something went wrong, e.g. the reader gives - // a wrong `equals` implementation. - readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), - reader, - isStreaming = true)) + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -506,5 +500,3 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } - -object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f87d57d0b3209..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(ds, _, output) => + case ContinuousExecutionRelation(_, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,8 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => - r.reader.asInstanceOf[ContinuousReader] + case DataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 70eb9f0ac66d5..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 254394685857b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case d: DataSourceV2Relation => d.reader + case DataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9ee9aaf87f87c..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.streaming.continuous -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import java.util.UUID + +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -40,7 +43,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 357babde5a8eb9710de7016d7ae82dee21fa4ef3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 14 Feb 2018 10:55:24 +0800 Subject: [PATCH 042/593] [SPARK-23399][SQL] Register a task completion listener first for OrcColumnarBatchReader ## What changes were proposed in this pull request? This PR aims to resolve an open file leakage issue reported at [SPARK-23390](https://issues.apache.org/jira/browse/SPARK-23390) by moving the listener registration position. Currently, the sequence is like the following. 1. Create `batchReader` 2. `batchReader.initialize` opens a ORC file. 3. `batchReader.initBatch` may take a long time to alloc memory in some environment and cause errors. 4. `Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))` This PR moves 4 before 2 and 3. To sum up, the new sequence is 1 -> 4 -> 2 -> 3. ## How was this patch tested? Manual. The following test case makes OOM intentionally to cause leaked filesystem connection in the current code base. With this patch, leakage doesn't occurs. ```scala // This should be tested manually because it raises OOM intentionally // in order to cause `Leaked filesystem connection`. test("SPARK-23399 Register a task completion listener first for OrcColumnarBatchReader") { withSQLConf(SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("orc").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("orc").save(new Path(basePath, "second").toString) val df = spark.read.orc( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20590 from dongjoon-hyun/SPARK-23399. --- .../sql/execution/datasources/orc/OrcFileFormat.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index dbf3bc6f0ee6c..1de2ca2914c44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -188,6 +188,12 @@ class OrcFileFormat if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, @@ -196,8 +202,6 @@ class OrcFileFormat partitionSchema, file.partitionValues) - val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) iter.asInstanceOf[Iterator[InternalRow]] } else { val orcRecordReader = new OrcInputFormat[OrcStruct] From 140f87533a468b1046504fc3ff01fbe1637e41cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Feb 2018 06:45:54 -0800 Subject: [PATCH 043/593] [SPARK-23394][UI] In RDD storage page show the executor addresses instead of the IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending RDD storage page to show executor addresses in the block table. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 10 30 59](https://user-images.githubusercontent.com/2017933/36142668-0b3578f8-10a9-11e8-95ea-2f57703ee4af.png) Author: “attilapiros” Closes #20589 from attilapiros/SPARK-23394. --- .../org/apache/spark/ui/storage/RDDPage.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 02cee7f8c5b33..2674b9291203a 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.status.api.v1.{ExecutorSummary, RDDDataDistribution, RDDPartitionInfo} import org.apache.spark.ui._ import org.apache.spark.util.Utils @@ -76,7 +76,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, - blockSortDesc) + blockSortDesc, + store.executorList(true)) _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -182,7 +183,8 @@ private[ui] class BlockDataSource( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + desc: Boolean, + executorIdToAddress: Map[String, String]) extends PagedDataSource[BlockTableRowData](pageSize) { private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) @@ -198,7 +200,10 @@ private[ui] class BlockDataSource( rddPartition.storageLevel, rddPartition.memoryUsed, rddPartition.diskUsed, - rddPartition.executors.mkString(" ")) + rddPartition.executors + .map { id => executorIdToAddress.get(id).getOrElse(id) } + .sorted + .mkString(" ")) } /** @@ -226,7 +231,8 @@ private[ui] class BlockPagedTable( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[BlockTableRowData] { + desc: Boolean, + executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] { override def tableId: String = "rdd-storage-by-block-table" @@ -243,7 +249,8 @@ private[ui] class BlockPagedTable( rddPartitions, pageSize, sortColumn, - desc) + desc, + executorSummaries.map { ex => (ex.id, ex.hostPort) }.toMap) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") From 400a1d9e25c1196f0be87323bd89fb3af0660166 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 10:57:12 -0800 Subject: [PATCH 044/593] Revert "[SPARK-23249][SQL] Improved block merging logic for partitions" This reverts commit 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9. --- .../sql/execution/DataSourceScanExec.scala | 29 +++++-------------- .../datasources/FileSourceStrategySuite.scala | 15 ++++++---- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index ba1157d5b6a49..08ff33afbba3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -444,29 +444,16 @@ case class FileSourceScanExec( currentSize = 0 } - def addFile(file: PartitionedFile): Unit = { - currentFiles += file - currentSize += file.length + openCostInBytes - } - - var frontIndex = 0 - var backIndex = splitFiles.length - 1 - - while (frontIndex <= backIndex) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - while (frontIndex <= backIndex && - currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - } - while (backIndex > frontIndex && - currentSize + splitFiles(backIndex).length <= maxSplitBytes) { - addFile(splitFiles(backIndex)) - backIndex -= 1 + // Assign files to partitions using "Next Fit Decreasing" + splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() } - closePartition() + // Add the given file to the current partition. + currentSize += file.length + openCostInBytes + currentFiles += file } + closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index bfccc9335b361..c1d61b843d899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,17 +141,16 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] - assert(partitions.size == 3, "when checking partitions") - assert(partitions(0).files.size == 2, "when checking partition 1") + // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] + assert(partitions.size == 4, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") + assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1, file6) + // First partition reads (file1) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) - assert(partitions(0).files(1).start == 0) - assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -164,6 +163,10 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) + + // Final partition reads (file6) + assert(partitions(3).files(0).start == 0) + assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From 658d9d9d785a30857bf35d164e6cbbd9799d6959 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 14 Feb 2018 14:27:02 -0800 Subject: [PATCH 045/593] [SPARK-23406][SS] Enable stream-stream self-joins ## What changes were proposed in this pull request? Solved two bugs to enable stream-stream self joins. ### Incorrect analysis due to missing MultiInstanceRelation trait Streaming leaf nodes did not extend MultiInstanceRelation, which is necessary for the catalyst analyzer to convert the self-join logical plan DAG into a tree (by creating new instances of the leaf relations). This was causing the error `Failure when resolving conflicting references in Join:` (see JIRA for details). ### Incorrect attribute rewrite when splicing batch plans in MicroBatchExecution When splicing the source's batch plan into the streaming plan (by replacing the StreamingExecutionPlan), we were rewriting the attribute reference in the streaming plan with the new attribute references from the batch plan. This was incorrectly handling the scenario when multiple StreamingExecutionRelation point to the same source, and therefore eventually point to the same batch plan returned by the source. Here is an example query, and its corresponding plan transformations. ``` val df = input.toDF val join = df.select('value % 5 as "key", 'value).join( df.select('value % 5 as "key", 'value), "key") ``` Streaming logical plan before splicing the batch plan ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- StreamingExecutionRelation Memory[#1], value#12 // two different leaves pointing to same source ``` Batch logical plan after splicing the batch plan and before rewriting ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#12 ``` Batch logical plan after rewriting the attributes. Specifically, for spliced, the new output attributes (value#66) replace the earlier output attributes (value#12, and value#1, one for each StreamingExecutionRelation). ``` Project [key#6, value#66, value#66] // both value#1 and value#12 replaces by value#66 +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9, value#66] +- LocalRelation [value#66] ``` This causes the optimizer to eliminate value#66 from one side of the join. ``` Project [key#6, value#66, value#66] +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9] // this does not generate value, incorrect join results +- LocalRelation [value#66] ``` **Solution**: Instead of rewriting attributes, use a Project to introduce aliases between the output attribute references and the new reference generated by the spliced plans. The analyzer and optimizer will take care of the rest. ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- Project [value#66 AS value#1] // solution: project with aliases : +- LocalRelation [value#66] +- Project [(value#12 % 5) AS key#9, value#12] +- Project [value#66 AS value#12] // solution: project with aliases +- LocalRelation [value#66] ``` ## How was this patch tested? New unit test Author: Tathagata Das Closes #20598 from tdas/SPARK-23406. --- .../streaming/MicroBatchExecution.scala | 16 ++++++------ .../streaming/StreamingRelation.scala | 20 ++++++++++----- .../sql/streaming/StreamingJoinSuite.scala | 25 ++++++++++++++++++- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..ac73ba3417904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} @@ -415,8 +415,6 @@ class MicroBatchExecution( } } - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => @@ -424,18 +422,18 @@ class MicroBatchExecution( assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + s"${Utils.truncatedString(dataPlan.output, ",")}") - replacements ++= output.zip(dataPlan.output) - dataPlan + + val aliases = output.zip(dataPlan.output).map { case (to, from) => + Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) + } + Project(aliases, dataPlan) }.getOrElse { LocalRelation(output, isStreaming = true) } } // Rewire the plan to use the new attributes that were returned by the source. - val replacementMap = AttributeMap(replacements) val newAttributePlan = newBatchesPlan transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 7146190645b37..f02d3a2c3733f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} @@ -42,7 +42,7 @@ object StreamingRelation { * passing to [[StreamExecution]] to run a query. */ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName @@ -53,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) } /** @@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: case class StreamingExecutionRelation( source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -74,6 +76,8 @@ case class StreamingExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } // We have to pack in the V1 data source as a shim, for the case when a source implements @@ -92,13 +96,15 @@ case class StreamingRelationV2( extraOptions: Map[String, String], output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** @@ -108,7 +114,7 @@ case class ContinuousExecutionRelation( source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -120,6 +126,8 @@ case class ContinuousExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 54eb863dacc83..92087f68ad74a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} -import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ @@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } + test("stream stream self join") { + val input = MemoryStream[Int] + val df = input.toDF + val join = + df.select('value % 5 as "key", 'value).join( + df.select('value % 5 as "key", 'value), "key") + + testStream(join)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + } + test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ From a77ebb0921e390cf4fc6279a8c0a92868ad7e69b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:52:59 -0800 Subject: [PATCH 046/593] [SPARK-23421][SPARK-22356][SQL] Document the behavior change in ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/19579 introduces a behavior change. We need to document it in the migration guide. ## How was this patch tested? Also update the HiveExternalCatalogVersionsSuite to verify it. Author: gatorsmile Closes #20606 from gatorsmile/addMigrationGuide. --- docs/sql-programming-guide.md | 2 ++ .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0f9f01e18682f..cf9529a79f4f9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1963,6 +1963,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). ## Upgrading From Spark SQL 2.0 to 2.1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index ae4aeb7b4ce4a..c13a750dbb270 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") protected var spark: SparkSession = _ @@ -249,7 +249,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { spark.sql("msck repair table " + tbl_with_col_overlap) assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) From 95e4b4916065e66a4f8dba57e98e725796f75e04 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:56:02 -0800 Subject: [PATCH 047/593] [SPARK-23094] Revert [] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? This PR is to revert the PR https://github.com/apache/spark/pull/20302, because it causes a regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20614 from gatorsmile/revertJsonFix. --- .../catalyst/json/CreateJacksonParser.scala | 5 ++- .../sources/JsonHadoopFsRelationSuite.scala | 34 ------------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index b1672e7e2fca2..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,11 +40,10 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) - jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) + jsonFactory.createParser(record.getBytes, 0, record.getLength) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) + jsonFactory.createParser(record) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 27f398ebf301a..49be30435ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - private val badJson = "\u0000\u0000\u0000A\u0001AAA" - // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -107,36 +105,4 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } - - test("invalid json with leading nulls - from file (multiLine=true)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val expected = s"""$badJson\n{"a":1}\n""" - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) - checkAnswer(df, Row(null, expected)) - } - } - - test("invalid json with leading nulls - from file (multiLine=false)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) - checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) - } - } - - test("invalid json with leading nulls - from dataset") { - import testImplicits._ - checkAnswer( - spark.read.json(Seq(badJson).toDS()), - Row(badJson)) - } } From f38c760638063f1fb45e9ee2c772090fb203a4a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Feb 2018 16:59:44 +0800 Subject: [PATCH 048/593] [SPARK-23419][SPARK-23416][SS] data source v2 write path should re-throw interruption exceptions directly ## What changes were proposed in this pull request? Streaming execution has a list of exceptions that means interruption, and handle them specially. `WriteToDataSourceV2Exec` should also respect this list and not wrap them with `SparkException`. ## How was this patch tested? existing test. Author: Wenchen Fan Closes #20605 from cloud-fan/write. --- .../datasources/v2/WriteToDataSourceV2.scala | 11 ++++- .../execution/streaming/StreamExecution.scala | 40 ++++++++++--------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 535e7962d7439..41cdfc80d8a19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.util.control.NonFatal + import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging @@ -27,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -107,7 +110,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e throw new SparkException("Writing job failed.", cause) } logError(s"Data source writer $writer aborted.") - throw new SparkException("Writing job aborted.", cause) + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } } sparkContext.emptyRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index e7982d7880ceb..3fc8c7887896a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -356,25 +356,7 @@ abstract class StreamExecution( private def isInterruptedByStop(e: Throwable): Boolean = { if (state.get == TERMINATED) { - e match { - // InterruptedIOException - thrown when an I/O operation is interrupted - // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted - case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => - true - // The cause of the following exceptions may be one of the above exceptions: - // - // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as - // BiFunction.apply - // ExecutionException - thrown by codes running in a thread pool and these codes throw an - // exception - // UncheckedExecutionException - thrown by codes that cannot throw a checked - // ExecutionException, such as BiFunction.apply - case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) - if e2.getCause != null => - isInterruptedByStop(e2.getCause) - case _ => - false - } + StreamExecution.isInterruptionException(e) } else { false } @@ -565,6 +547,26 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + + def isInterruptionException(e: Throwable): Boolean = e match { + // InterruptedIOException - thrown when an I/O operation is interrupted + // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => + true + // The cause of the following exceptions may be one of the above exceptions: + // + // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as + // BiFunction.apply + // ExecutionException - thrown by codes running in a thread pool and these codes throw an + // exception + // UncheckedExecutionException - thrown by codes that cannot throw a checked + // ExecutionException, such as BiFunction.apply + case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) + if e2.getCause != null => + isInterruptionException(e2.getCause) + case _ => + false + } } /** From 7539ae59d6c354c95c50528abe9ddff6972e960f Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 15 Feb 2018 17:09:06 +0800 Subject: [PATCH 049/593] [SPARK-23366] Improve hot reading path in ReadAheadInputStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `ReadAheadInputStream` was introduced in https://github.com/apache/spark/pull/18317/ to optimize reading spill files from disk. However, from the profiles it seems that the hot path of reading small amounts of data (like readInt) is inefficient - it involves taking locks, and multiple checks. Optimize locking: Lock is not needed when simply accessing the active buffer. Only lock when needing to swap buffers or trigger async reading, or get information about the async state. Optimize short-path single byte reads, that are used e.g. by Java library DataInputStream.readInt. The asyncReader used to call "read" only once on the underlying stream, that never filled the underlying buffer when it was wrapping an LZ4BlockInputStream. If the buffer was returned unfilled, that would trigger the async reader to be triggered to fill the read ahead buffer on each call, because the reader would see that the active buffer is below the refill threshold all the time. However, filling the full buffer all the time could introduce increased latency, so also add an `AtomicBoolean` flag for the async reader to return earlier if there is a reader waiting for data. Remove `readAheadThresholdInBytes` and instead immediately trigger async read when switching the buffers. It allows to simplify code paths, especially the hot one that then only has to check if there is available data in the active buffer, without worrying if it needs to retrigger async read. It seems to have positive effect on perf. ## How was this patch tested? It was noticed as a regression in some workloads after upgrading to Spark 2.3.  It was particularly visible on TPCDS Q95 running on instances with fast disk (i3 AWS instances). Running with profiling: * Spark 2.2 - 5.2-5.3 minutes 9.5% in LZ4BlockInputStream.read * Spark 2.3 - 6.4-6.6 minutes 31.1% in ReadAheadInputStream.read * Spark 2.3 + fix - 5.3-5.4 minutes 13.3% in ReadAheadInputStream.read - very slightly slower, practically within noise. We didn't see other regressions, and many workloads in general seem to be faster with Spark 2.3 (not investigated if thanks to async readed, or unrelated). Author: Juliusz Sompolski Closes #20555 from juliuszsompolski/SPARK-23366. --- .../apache/spark/io/ReadAheadInputStream.java | 119 +++++++++--------- .../unsafe/sort/UnsafeSorterSpillReader.java | 10 +- .../spark/io/GenericFileInputStreamSuite.java | 98 ++++++++------- .../spark/io/NioBufferedInputStreamSuite.java | 6 +- .../spark/io/ReadAheadInputStreamSuite.java | 17 ++- 5 files changed, 133 insertions(+), 117 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5b45d268ace8d..0cced9e222952 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream { // whether there is a read ahead task running, private boolean isReading; - // If the remaining data size in the current buffer is below this threshold, - // we issue an async read from the underlying input stream. - private final int readAheadThresholdInBytes; + // whether there is a reader waiting for data. + private AtomicBoolean isWaiting = new AtomicBoolean(false); private final InputStream underlyingInputStream; @@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream { * * @param inputStream The underlying input stream. * @param bufferSizeInBytes The buffer size. - * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead - * threshold, an async read is triggered. */ public ReadAheadInputStream( - InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + InputStream inputStream, int bufferSizeInBytes) { Preconditions.checkArgument(bufferSizeInBytes > 0, "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); - Preconditions.checkArgument(readAheadThresholdInBytes > 0 && - readAheadThresholdInBytes < bufferSizeInBytes, - "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " + - "but the value is " + readAheadThresholdInBytes); activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); - this.readAheadThresholdInBytes = readAheadThresholdInBytes; this.underlyingInputStream = inputStream; activeBuffer.flip(); readAheadBuffer.flip(); @@ -166,12 +159,17 @@ public void run() { // in that case the reader waits for this async read to complete. // So there is no race condition in both the situations. int read = 0; + int off = 0, len = arr.length; Throwable exception = null; try { - while (true) { - read = underlyingInputStream.read(arr); - if (0 != read) break; - } + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); } catch (Throwable ex) { exception = ex; if (ex instanceof Error) { @@ -181,13 +179,12 @@ public void run() { } } finally { stateChangeLock.lock(); + readAheadBuffer.limit(off); if (read < 0 || (exception instanceof EOFException)) { endOfStream = true; } else if (exception != null) { readAborted = true; readException = exception; - } else { - readAheadBuffer.limit(read); } readInProgress = false; signalAsyncReadComplete(); @@ -230,7 +227,10 @@ private void signalAsyncReadComplete() { private void waitForAsyncReadComplete() throws IOException { stateChangeLock.lock(); + isWaiting.set(true); try { + // There is only one reader, and one writer, so the writer should signal only once, + // but a while loop checking the wake up condition is still needed to avoid spurious wakeups. while (readInProgress) { asyncReadComplete.await(); } @@ -239,6 +239,7 @@ private void waitForAsyncReadComplete() throws IOException { iio.initCause(e); throw iio; } finally { + isWaiting.set(false); stateChangeLock.unlock(); } checkReadException(); @@ -246,8 +247,13 @@ private void waitForAsyncReadComplete() throws IOException { @Override public int read() throws IOException { - byte[] oneByteArray = oneByte.get(); - return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + if (activeBuffer.hasRemaining()) { + // short path - just get one byte. + return activeBuffer.get() & 0xFF; + } else { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + } } @Override @@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException { if (len == 0) { return 0; } - stateChangeLock.lock(); - try { - return readInternal(b, offset, len); - } finally { - stateChangeLock.unlock(); - } - } - /** - * flip the active and read ahead buffer - */ - private void swapBuffers() { - ByteBuffer temp = activeBuffer; - activeBuffer = readAheadBuffer; - readAheadBuffer = temp; - } - - /** - * Internal read function which should be called only from read() api. The assumption is that - * the stateChangeLock is already acquired in the caller before calling this function. - */ - private int readInternal(byte[] b, int offset, int len) throws IOException { - assert (stateChangeLock.isLocked()); if (!activeBuffer.hasRemaining()) { - waitForAsyncReadComplete(); - if (readAheadBuffer.hasRemaining()) { - swapBuffers(); - } else { - // The first read or activeBuffer is skipped. - readAsync(); + // No remaining in active buffer - lock and switch to write ahead buffer. + stateChangeLock.lock(); + try { waitForAsyncReadComplete(); - if (isEndOfStream()) { - return -1; + if (!readAheadBuffer.hasRemaining()) { + // The first read. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } } + // Swap the newly read read ahead buffer in place of empty active buffer. swapBuffers(); + // After swapping buffers, trigger another async read for read ahead buffer. + readAsync(); + } finally { + stateChangeLock.unlock(); } - } else { - checkReadException(); } len = Math.min(len, activeBuffer.remaining()); activeBuffer.get(b, offset, len); - if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { - readAsync(); - } return len; } + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + @Override public int available() throws IOException { stateChangeLock.lock(); @@ -323,6 +318,11 @@ public long skip(long n) throws IOException { if (n <= 0L) { return 0L; } + if (n <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position((int) n + activeBuffer.position()); + return n; + } stateChangeLock.lock(); long skipped; try { @@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException { if (available() >= n) { // we can skip from the internal buffers int toSkip = (int) n; - if (toSkip <= activeBuffer.remaining()) { - // Only skipping from active buffer is sufficient - activeBuffer.position(toSkip + activeBuffer.position()); - if (activeBuffer.remaining() <= readAheadThresholdInBytes - && !readAheadBuffer.hasRemaining()) { - readAsync(); - } - return n; - } // We need to skip from both active buffer and read ahead buffer toSkip -= activeBuffer.remaining(); + assert(toSkip > 0); // skipping from activeBuffer already handled. activeBuffer.position(0); activeBuffer.flip(); readAheadBuffer.position(toSkip + readAheadBuffer.position()); swapBuffers(); + // Trigger async read to emptied read ahead buffer. readAsync(); return n; } else { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 2c53c8d809d2e..fb179d07edebc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,21 +72,15 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } - final double readAheadFraction = - SparkEnv.get() == null ? 0.5 : - SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf - // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { if (readAheadEnabled) { this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), - (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + (int) bufferSizeBytes); } else { this.in = serializerManager.wrapStream(blockId, bs); } diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 3440e1aea2f46..22db3592ecc96 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite { protected File inputFile; - protected InputStream inputStream; + protected InputStream[] inputStreams; @Before public void setUp() throws IOException { @@ -54,77 +54,91 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - for (int i = 0; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testReadMultipleBytes() throws IOException { - byte[] readBytes = new byte[8 * 1024]; - int i = 0; - while (i < randomBytes.length) { - int read = inputStream.read(readBytes, 0, 8 * 1024); - for (int j = 0; j < read; j++) { - assertEquals(randomBytes[i], readBytes[j]); - i++; + for (InputStream inputStream: inputStreams) { + byte[] readBytes = new byte[8 * 1024]; + int i = 0; + while (i < randomBytes.length) { + int read = inputStream.read(readBytes, 0, 8 * 1024); + for (int j = 0; j < read; j++) { + assertEquals(randomBytes[i], readBytes[j]); + i++; + } } } } @Test public void testBytesSkipped() throws IOException { - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - // Skipping negative bytes should essential be a no-op - assertEquals(0, inputStream.skip(-1)); - assertEquals(0, inputStream.skip(-1024)); - assertEquals(0, inputStream.skip(Long.MIN_VALUE)); - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + // Skipping negative bytes should essential be a no-op + assertEquals(0, inputStream.skip(-1)); + assertEquals(0, inputStream.skip(-1024)); + assertEquals(0, inputStream.skip(Long.MIN_VALUE)); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testSkipFromFileChannel() throws IOException { - // Since the buffer is smaller than the skipped bytes, this will guarantee - // we skip from underlying file channel. - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < 2048; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(256, inputStream.skip(256)); - assertEquals(256, inputStream.skip(256)); - assertEquals(512, inputStream.skip(512)); - for (int i = 3072; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + // Since the buffer is smaller than the skipped bytes, this will guarantee + // we skip from underlying file channel. + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < 2048; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(256, inputStream.skip(256)); + assertEquals(256, inputStream.skip(256)); + assertEquals(512, inputStream.skip(512)); + for (int i = 3072; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterEOF() throws IOException { - assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); - assertEquals(-1, inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); + assertEquals(-1, inputStream.read()); + } } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java index 211b33a1a9fb0..a320f8662f707 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -18,6 +18,7 @@ import org.junit.Before; +import java.io.InputStream; import java.io.IOException; /** @@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new NioBufferedFileInputStream(inputFile); + inputStreams = new InputStream[] { + new NioBufferedFileInputStream(inputFile), // default + new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer + }; } } diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java index 918ddc4517ec4..bfa1e0b908824 100644 --- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -19,16 +19,27 @@ import org.junit.Before; import java.io.IOException; +import java.io.InputStream; /** - * Tests functionality of {@link NioBufferedFileInputStream} + * Tests functionality of {@link ReadAheadInputStreamSuite} */ public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new ReadAheadInputStream( - new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + inputStreams = new InputStream[] { + // Tests equal and aligned buffers of wrapped an outer stream. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024), + // Tests aligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024), + // Tests aligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024), + // Tests unaligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123), + // Tests unaligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321) + }; } } From ed8647609883fcef16be5d24c2cb4ebda25bd6f0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Feb 2018 17:13:05 +0800 Subject: [PATCH 050/593] [SPARK-23359][SQL] Adds an alias 'names' of 'fieldNames' in Scala's StructType ## What changes were proposed in this pull request? This PR proposes to add an alias 'names' of 'fieldNames' in Scala. Please see the discussion in [SPARK-20090](https://issues.apache.org/jira/browse/SPARK-20090). ## How was this patch tested? Unit tests added in `DataTypeSuite.scala`. Author: hyukjinkwon Closes #20545 from HyukjinKwon/SPARK-23359. --- .../scala/org/apache/spark/sql/types/StructType.scala | 7 +++++++ .../scala/org/apache/spark/sql/types/DataTypeSuite.scala | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e3b0969283a84..d5011c3cb87e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -104,6 +104,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) + /** + * Returns all field names in an array. This is an alias of `fieldNames`. + * + * @since 2.4.0 + */ + def names: Array[String] = fieldNames + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 8e2b32c2b9a08..5a86f4055dce7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -134,6 +134,14 @@ class DataTypeSuite extends SparkFunSuite { assert(mapped === expected) } + test("fieldNames and names returns field names") { + val struct = StructType( + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) + + assert(struct.fieldNames === Seq("a", "b")) + assert(struct.names === Seq("a", "b")) + } + test("merge where right contains type conflict") { val left = StructType( StructField("a", LongType) :: From 44e20c42254bc6591b594f54cd94ced5fcfadae3 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 15 Feb 2018 03:52:40 -0800 Subject: [PATCH 051/593] =?UTF-8?q?[SPARK-23422][CORE]=20YarnShuffleIntegr?= =?UTF-8?q?ationSuite=20fix=20when=20SPARK=5FPREPEN=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …D_CLASSES set to 1 ## What changes were proposed in this pull request? YarnShuffleIntegrationSuite fails when SPARK_PREPEND_CLASSES set to 1. Normally mllib built before yarn module. When SPARK_PREPEND_CLASSES used mllib classes are on yarn test classpath. Before 2.3 that did not cause issues. But 2.3 has SPARK-22450, which registered some mllib classes with the kryo serializer. Now it dies with the following error: ` 18/02/13 07:33:29 INFO SparkContext: Starting job: collect at YarnShuffleIntegrationSuite.scala:143 Exception in thread "dag-scheduler-event-loop" java.lang.NoClassDefFoundError: breeze/linalg/DenseMatrix ` In this PR NoClassDefFoundError caught only in case of testing and then do nothing. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20608 from gaborgsomogyi/SPARK-23422. --- .../main/scala/org/apache/spark/serializer/KryoSerializer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 538ae05e4eea1..72427dd6ce4d4 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -206,6 +206,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(clazz) } catch { case NonFatal(_) => // do nothing + case _: NoClassDefFoundError if Utils.isTesting => // See SPARK-23422. } } From f217d7d9b22c4b9c947fc5467379af17f036ee61 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Feb 2018 07:47:40 -0800 Subject: [PATCH 052/593] [INFRA] Close stale PRs. Closes #20587 Closes #20586 From 2f0498d1e85a53b60da6a47d20bbdf56b42b7dcb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 08:55:39 -0800 Subject: [PATCH 053/593] [SPARK-23426][SQL] Use `hive` ORC impl and disable PPD for Spark 2.3.0 ## What changes were proposed in this pull request? To prevent any regressions, this PR changes ORC implementation to `hive` by default like Spark 2.2.X. Users can enable `native` ORC. Also, ORC PPD is also restored to `false` like Spark 2.2.X. ![orc_section](https://user-images.githubusercontent.com/9700541/36221575-57a1d702-1173-11e8-89fe-dca5842f4ca7.png) ## How was this patch tested? Pass all test cases. Author: Dongjoon Hyun Closes #20610 from dongjoon-hyun/SPARK-ORC-DISABLE. --- docs/sql-programming-guide.md | 52 ++++++++----------- .../apache/spark/sql/internal/SQLConf.scala | 6 +-- .../spark/sql/FileBasedDataSourceSuite.scala | 17 +++++- .../sql/streaming/FileStreamSinkSuite.scala | 13 +++++ .../sql/streaming/FileStreamSourceSuite.scala | 13 +++++ 5 files changed, 68 insertions(+), 33 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cf9529a79f4f9..91e43678481d6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1004,6 +1004,29 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession +## ORC Files + +Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. +To do that, the following configurations are newly added. The vectorized reader is used for the +native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` +is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC +serde tables (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), +the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also set to `true`. + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.orc.implhiveThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    + ## JSON Datasets
    @@ -1776,35 +1799,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 - - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. - - - New configurations - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    - - - Changed configurations - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
    - - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7835dbaa58439..f24fd7ff74d3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default prior to Spark 2.3.") + "1.2.1. It is 'hive' by default.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("native") + .createWithDefault("hive") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 2e332362ea644..b5d4c558f0d3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -20,14 +20,29 @@ package org.apache.spark.sql import java.io.FileNotFoundException import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") private val nameWithSpecialChars = "sp&cial%c hars" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 8c4e1fd00b0a2..ba48bc1ce0c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -33,6 +33,19 @@ import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + test("unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5bb0f4d643bbe..d4bd9c7987f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -207,6 +207,19 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head } + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + // ============= Basic parameter exists tests ================ test("FileStreamSource schema: no path") { From 6968c3cfd70961c4e86daffd6a156d0a9c1d7a2a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 09:40:08 -0800 Subject: [PATCH 054/593] [MINOR][SQL] Fix an error message about inserting into bucketed tables ## What changes were proposed in this pull request? This replaces `Sparkcurrently` to `Spark currently` in the following error message. ```scala scala> sql("insert into t2 select * from v1") org.apache.spark.sql.AnalysisException: Output Hive table `default`.`t2` is bucketed but Sparkcurrently does NOT populate bucketed ... ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20617 from dongjoon-hyun/SPARK-ERROR-MSG. --- .../apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6fc..02a60f16b3b3a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -172,7 +172,7 @@ case class InsertIntoHiveTable( val enforceBucketingConfig = "hive.enforce.bucketing" val enforceSortingConfig = "hive.enforce.sorting" - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + + val message = s"Output Hive table ${table.identifier} is bucketed but Spark " + "currently does NOT populate bucketed output which is compatible with Hive." if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || From db45daab90ede4c03c1abc9096f4eac584e9db17 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Feb 2018 09:54:39 -0800 Subject: [PATCH 055/593] [SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug ## What changes were proposed in this pull request? #### Problem: Since 2.3, `Bucketizer` supports multiple input/output columns. We will check if exclusive params are set during transformation. E.g., if `inputCols` and `outputCol` are both set, an error will be thrown. However, when we write `Bucketizer`, looks like the default params and user-supplied params are merged during writing. All saved params are loaded back and set to created model instance. So the default `outputCol` param in `HasOutputCol` trait will be set in `paramMap` and become an user-supplied param. That makes the check of exclusive params failed. #### Fix: This changes the saving logic of Bucketizer to handle this case. This is a quick fix to catch the time of 2.3. We should consider modify the persistence mechanism later. Please see the discussion in the JIRA. Note: The multi-column `QuantileDiscretizer` also has the same issue. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20594 from viirya/SPARK-23377-2. --- .../apache/spark/ml/feature/Bucketizer.scala | 28 +++++++++++++++++++ .../ml/feature/QuantileDiscretizer.scala | 28 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 12 ++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 14 ++++++++-- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index c13bf47eacb94..f49c410cbcfe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -19,6 +19,10 @@ package org.apache.spark.ml.feature import java.{util => ju} +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Model @@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } + + override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + + private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1ec5f8cb6139b..3b4c25478fb1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.feature +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) + + override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + private[QuantileDiscretizer] + class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 7403680ae3fdc..41cf72fe3470a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setSplits(Array(0.1, 0.8, 0.9)) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) } test("Bucket numeric features") { @@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) + assert(t.hasDefault(t.outputCol)) + assert(bucketizer.hasDefault(bucketizer.outputCol)) } test("Bucketizer in a pipeline") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index e9a75e931e6a8..6c363799dd300 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("Test observed number of buckets and their sizes match expected values") { val spark = this.spark import spark.implicits._ @@ -132,7 +134,10 @@ class QuantileDiscretizerSuite .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setNumBuckets(6) - testDefaultReadWrite(t) + + val readDiscretizer = testDefaultReadWrite(t) + val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol") + readDiscretizer.fit(data) } test("Verify resulting model has parent") { @@ -379,7 +384,12 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBucketsArray(Array(5, 10)) - testDefaultReadWrite(discretizer) + + val readDiscretizer = testDefaultReadWrite(discretizer) + val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2") + readDiscretizer.fit(data) + assert(discretizer.hasDefault(discretizer.outputCol)) + assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } test("Multiple Columns: Both inputCol and inputCols are set") { From 1dc2c1d5e85c5f404f470aeb44c1f3c22786bdea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 15 Feb 2018 13:51:24 -0600 Subject: [PATCH 056/593] [SPARK-23413][UI] Fix sorting tasks by Host / Executor ID at the Stage page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fixing exception got at sorting tasks by Host / Executor ID: ``` java.lang.IllegalArgumentException: Invalid sort column: Host at org.apache.spark.ui.jobs.ApiHelper$.indexName(StagePage.scala:1017) at org.apache.spark.ui.jobs.TaskDataSource.sliceData(StagePage.scala:694) at org.apache.spark.ui.PagedDataSource.pageData(PagedTable.scala:61) at org.apache.spark.ui.PagedTable$class.table(PagedTable.scala:96) at org.apache.spark.ui.jobs.TaskPagedTable.table(StagePage.scala:708) at org.apache.spark.ui.jobs.StagePage.liftedTree1$1(StagePage.scala:293) at org.apache.spark.ui.jobs.StagePage.render(StagePage.scala:282) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) ``` Moreover some refactoring to avoid similar problems by introducing constants for each header name and reusing them at the identification of the corresponding sorting index. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 18 57 10](https://user-images.githubusercontent.com/2017933/36166532-1cfdf3b8-10f3-11e8-8d32-5fcaad2af214.png) Author: “attilapiros” Closes #20601 from attilapiros/SPARK-23413. --- .../org/apache/spark/status/storeTypes.scala | 2 + .../org/apache/spark/ui/jobs/StagePage.scala | 121 +++++++++++------- .../org/apache/spark/ui/StagePageSuite.scala | 63 ++++++++- 3 files changed, 139 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 412644d3657b5..646cf25880e37 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -109,6 +109,7 @@ private[spark] object TaskIndexNames { final val DURATION = "dur" final val ERROR = "err" final val EXECUTOR = "exe" + final val HOST = "hst" final val EXEC_CPU_TIME = "ect" final val EXEC_RUN_TIME = "ert" final val GC_TIME = "gc" @@ -165,6 +166,7 @@ private[spark] class TaskDataWrapper( val duration: Long, @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) val executorId: String, + @KVIndexParam(value = TaskIndexNames.HOST, parent = TaskIndexNames.STAGE) val host: String, @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) val status: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5c2b0c3a19996..a9265d4dbcdfb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -750,37 +750,39 @@ private[ui] class TaskPagedTable( } def headers: Seq[Node] = { + import ApiHelper._ + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), - ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + (HEADER_TASK_INDEX, ""), (HEADER_ID, ""), (HEADER_ATTEMPT, ""), (HEADER_STATUS, ""), + (HEADER_LOCALITY, ""), (HEADER_EXECUTOR, ""), (HEADER_HOST, ""), (HEADER_LAUNCH_TIME, ""), + (HEADER_DURATION, ""), (HEADER_SCHEDULER_DELAY, TaskDetailsClassNames.SCHEDULER_DELAY), + (HEADER_DESER_TIME, TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + (HEADER_GC_TIME, ""), + (HEADER_SER_TIME, TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + (HEADER_GETTING_RESULT_TIME, TaskDetailsClassNames.GETTING_RESULT_TIME), + (HEADER_PEAK_MEM, TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ + {if (hasAccumulators(stage)) Seq((HEADER_ACCUMULATORS, "")) else Nil} ++ + {if (hasInput(stage)) Seq((HEADER_INPUT_SIZE, "")) else Nil} ++ + {if (hasOutput(stage)) Seq((HEADER_OUTPUT_SIZE, "")) else Nil} ++ {if (hasShuffleRead(stage)) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + Seq((HEADER_SHUFFLE_READ_TIME, TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + (HEADER_SHUFFLE_TOTAL_READS, ""), + (HEADER_SHUFFLE_REMOTE_READS, TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ {if (hasShuffleWrite(stage)) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + Seq((HEADER_SHUFFLE_WRITE_TIME, ""), (HEADER_SHUFFLE_WRITE_SIZE, "")) } else { Nil }} ++ {if (hasBytesSpilled(stage)) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + Seq((HEADER_MEM_SPILL, ""), (HEADER_DISK_SPILL, "")) } else { Nil }} ++ - Seq(("Errors", "")) + Seq((HEADER_ERROR, "")) if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { throw new IllegalArgumentException(s"Unknown column: $sortColumn") @@ -961,35 +963,62 @@ private[ui] class TaskPagedTable( } } -private object ApiHelper { - - - private val COLUMN_TO_INDEX = Map( - "ID" -> null.asInstanceOf[String], - "Index" -> TaskIndexNames.TASK_INDEX, - "Attempt" -> TaskIndexNames.ATTEMPT, - "Status" -> TaskIndexNames.STATUS, - "Locality Level" -> TaskIndexNames.LOCALITY, - "Executor ID / Host" -> TaskIndexNames.EXECUTOR, - "Launch Time" -> TaskIndexNames.LAUNCH_TIME, - "Duration" -> TaskIndexNames.DURATION, - "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, - "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, - "GC Time" -> TaskIndexNames.GC_TIME, - "Result Serialization Time" -> TaskIndexNames.SER_TIME, - "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, - "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, - "Accumulators" -> TaskIndexNames.ACCUMULATORS, - "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, - "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, - "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, - "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, - "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, - "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, - "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, - "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, - "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, - "Errors" -> TaskIndexNames.ERROR) +private[ui] object ApiHelper { + + val HEADER_ID = "ID" + val HEADER_TASK_INDEX = "Index" + val HEADER_ATTEMPT = "Attempt" + val HEADER_STATUS = "Status" + val HEADER_LOCALITY = "Locality Level" + val HEADER_EXECUTOR = "Executor ID" + val HEADER_HOST = "Host" + val HEADER_LAUNCH_TIME = "Launch Time" + val HEADER_DURATION = "Duration" + val HEADER_SCHEDULER_DELAY = "Scheduler Delay" + val HEADER_DESER_TIME = "Task Deserialization Time" + val HEADER_GC_TIME = "GC Time" + val HEADER_SER_TIME = "Result Serialization Time" + val HEADER_GETTING_RESULT_TIME = "Getting Result Time" + val HEADER_PEAK_MEM = "Peak Execution Memory" + val HEADER_ACCUMULATORS = "Accumulators" + val HEADER_INPUT_SIZE = "Input Size / Records" + val HEADER_OUTPUT_SIZE = "Output Size / Records" + val HEADER_SHUFFLE_READ_TIME = "Shuffle Read Blocked Time" + val HEADER_SHUFFLE_TOTAL_READS = "Shuffle Read Size / Records" + val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads" + val HEADER_SHUFFLE_WRITE_TIME = "Write Time" + val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records" + val HEADER_MEM_SPILL = "Shuffle Spill (Memory)" + val HEADER_DISK_SPILL = "Shuffle Spill (Disk)" + val HEADER_ERROR = "Errors" + + private[ui] val COLUMN_TO_INDEX = Map( + HEADER_ID -> null.asInstanceOf[String], + HEADER_TASK_INDEX -> TaskIndexNames.TASK_INDEX, + HEADER_ATTEMPT -> TaskIndexNames.ATTEMPT, + HEADER_STATUS -> TaskIndexNames.STATUS, + HEADER_LOCALITY -> TaskIndexNames.LOCALITY, + HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, + HEADER_HOST -> TaskIndexNames.HOST, + HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, + HEADER_DURATION -> TaskIndexNames.DURATION, + HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, + HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, + HEADER_GC_TIME -> TaskIndexNames.GC_TIME, + HEADER_SER_TIME -> TaskIndexNames.SER_TIME, + HEADER_GETTING_RESULT_TIME -> TaskIndexNames.GETTING_RESULT_TIME, + HEADER_PEAK_MEM -> TaskIndexNames.PEAK_MEM, + HEADER_ACCUMULATORS -> TaskIndexNames.ACCUMULATORS, + HEADER_INPUT_SIZE -> TaskIndexNames.INPUT_SIZE, + HEADER_OUTPUT_SIZE -> TaskIndexNames.OUTPUT_SIZE, + HEADER_SHUFFLE_READ_TIME -> TaskIndexNames.SHUFFLE_READ_TIME, + HEADER_SHUFFLE_TOTAL_READS -> TaskIndexNames.SHUFFLE_TOTAL_READS, + HEADER_SHUFFLE_REMOTE_READS -> TaskIndexNames.SHUFFLE_REMOTE_READS, + HEADER_SHUFFLE_WRITE_TIME -> TaskIndexNames.SHUFFLE_WRITE_TIME, + HEADER_SHUFFLE_WRITE_SIZE -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + HEADER_MEM_SPILL -> TaskIndexNames.MEM_SPILL, + HEADER_DISK_SPILL -> TaskIndexNames.DISK_SPILL, + HEADER_ERROR -> TaskIndexNames.ERROR) def hasAccumulators(stageData: StageData): Boolean = { stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 0aeddf730cd35..6044563f7dde7 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,13 +28,74 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} import org.apache.spark.status.config._ -import org.apache.spark.ui.jobs.{StagePage, StagesTab} +import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable} class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 + test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + val statusStore = AppStatusStore.createLiveStore(conf) + try { + val stageData = new StageData( + status = StageStatus.ACTIVE, + stageId = 1, + attemptId = 1, + numTasks = 1, + numActiveTasks = 1, + numCompleteTasks = 1, + numFailedTasks = 1, + numKilledTasks = 1, + numCompletedIndices = 1, + + executorRunTime = 1L, + executorCpuTime = 1L, + submissionTime = None, + firstTaskLaunchedTime = None, + completionTime = None, + failureReason = None, + + inputBytes = 1L, + inputRecords = 1L, + outputBytes = 1L, + outputRecords = 1L, + shuffleReadBytes = 1L, + shuffleReadRecords = 1L, + shuffleWriteBytes = 1L, + shuffleWriteRecords = 1L, + memoryBytesSpilled = 1L, + diskBytesSpilled = 1L, + + name = "stage1", + description = Some("description"), + details = "detail", + schedulingPool = "pool1", + + rddIds = Seq(1), + accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), + tasks = None, + executorSummary = None, + killedTasksSummary = Map.empty + ) + val taskTable = new TaskPagedTable( + stageData, + basePath = "/a/b/c", + currentTime = 0, + pageSize = 10, + sortColumn = "Index", + desc = false, + store = statusStore + ) + val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet + assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet) + } finally { + statusStore.close() + } + } + test("peak execution memory should displayed") { val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" From c5857e496ff0d170ed0339f14afc7d36b192da6d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 16 Feb 2018 09:41:17 -0800 Subject: [PATCH 057/593] [SPARK-23446][PYTHON] Explicitly check supported types in toPandas ## What changes were proposed in this pull request? This PR explicitly specifies and checks the types we supported in `toPandas`. This was a hole. For example, we haven't finished the binary type support in Python side yet but now it allows as below: ```python spark.conf.set("spark.sql.execution.arrow.enabled", "false") df = spark.createDataFrame([[bytearray("a")]]) df.toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", "true") df.toPandas() ``` ``` _1 0 [97] _1 0 a ``` This should be disallowed. I think the same things also apply to nested timestamps too. I also added some nicer message about `spark.sql.execution.arrow.enabled` in the error message. ## How was this patch tested? Manually tested and tests added in `python/pyspark/sql/tests.py`. Author: hyukjinkwon Closes #20625 from HyukjinKwon/pandas_convertion_supported_type. --- python/pyspark/sql/dataframe.py | 15 +++++++++------ python/pyspark/sql/tests.py | 9 ++++++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cc8b63cdfadf..f37777e13ee12 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1988,10 +1988,11 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps + _check_dataframe_localize_timestamps, to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version - import pyarrow require_minimum_pyarrow_version() + import pyarrow + to_arrow_schema(self.schema) tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) @@ -2000,10 +2001,12 @@ def toPandas(self): return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (_exception_message(e), msg)) + except Exception as e: + msg = ( + "Note: toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " + "to disable this.") + raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2af218a691026..19653072ea316 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3497,7 +3497,14 @@ def test_unsupported_datatype(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() + + df = self.spark.createDataFrame([(None,)], schema="a binary") + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): df.toPandas() def test_null_conversion(self): From 0a73aa31f41c83503d5d99eff3c9d7b406014ab3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Feb 2018 14:30:19 -0800 Subject: [PATCH 058/593] [SPARK-23362][SS] Migrate Kafka Microbatch source to v2 ## What changes were proposed in this pull request? Migrating KafkaSource (with data source v1) to KafkaMicroBatchReader (with data source v2). Performance comparison: In a unit test with in-process Kafka broker, I tested the read throughput of V1 and V2 using 20M records in a single partition. They were comparable. ## How was this patch tested? Existing tests, few modified to be better tests than the existing ones. Author: Tathagata Das Closes #20554 from tdas/SPARK-23362. --- dev/.rat-excludes | 1 + .../sql/kafka010/CachedKafkaConsumer.scala | 2 +- .../sql/kafka010/KafkaContinuousReader.scala | 29 +- .../sql/kafka010/KafkaMicroBatchReader.scala | 403 ++++++++++++++++++ .../KafkaRecordToUnsafeRowConverter.scala | 52 +++ .../spark/sql/kafka010/KafkaSource.scala | 19 +- .../sql/kafka010/KafkaSourceProvider.scala | 70 ++- ...a-source-initial-offset-future-version.bin | 2 + ...ka-source-initial-offset-version-2.1.0.bin | 2 +- ...scala => KafkaMicroBatchSourceSuite.scala} | 254 +++++++---- .../apache/spark/sql/internal/SQLConf.scala | 15 +- .../streaming/MicroBatchExecution.scala | 20 +- 12 files changed, 741 insertions(+), 128 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala create mode 100644 external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin rename external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/{KafkaSourceSuite.scala => KafkaMicroBatchSourceSuite.scala} (85%) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 243fbe3e1bc24..9552d001a079c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -105,3 +105,4 @@ META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin +kafka-source-initial-offset-future-version.bin diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 90ed7b1fba2f8..e97881cb0a163 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index b049a054cb40e..97a0f66e1880d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType @@ -187,13 +187,9 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val consumer = + CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ @@ -232,22 +228,7 @@ class KafkaContinuousDataReader( } override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + converter.toUnsafeRow(currentRecord) } override def getOffset(): KafkaSourcePartitionOffset = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala new file mode 100644 index 0000000000000..fb647ca7e70dd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[MicroBatchReader]] that reads data from Kafka. + * + * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using Kafka maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. + */ +private[kafka010] class KafkaMicroBatchReader( + kafkaOffsetReader: KafkaOffsetReader, + executorKafkaParams: ju.Map[String, Object], + options: DataSourceOptions, + metadataPath: String, + startingOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + + type PartitionOffsetMap = Map[TopicPartition, Long] + + private var startPartitionOffsets: PartitionOffsetMap = _ + private var endPartitionOffsets: PartitionOffsetMap = _ + + private val pollTimeoutMs = options.getLong( + "kafkaConsumer.pollTimeoutMs", + SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + + private val maxOffsetsPerTrigger = + Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() + + override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + startPartitionOffsets = Option(start.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse(initialPartitionOffsets) + + endPartitionOffsets = Option(end.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse { + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + } + } + } + + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + // Find the new partitions, and get their earliest offsets + val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) + val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + // Find deleted partitions, and report data loss if required + val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = endPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. + newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList() + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val factories = topicPartitions.flatMap { tp => + val fromOffset = startPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse( + tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = endPartitionOffsets(tp) + + if (untilOffset >= fromOffset) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + val preferredLoc = if (numExecutors > 0) { + Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) + } else None + val range = KafkaOffsetRange(tp, fromOffset, untilOffset) + Some( + new KafkaMicroBatchDataReaderFactory( + range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) + } else { + reportDataLoss( + s"Partition $tp's offset was changed from " + + s"$fromOffset to $untilOffset, some data may have been missed") + None + } + } + factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + } + + override def getStartOffset: Offset = { + KafkaSourceOffset(startPartitionOffsets) + } + + override def getEndOffset: Offset = { + KafkaSourceOffset(endPartitionOffsets) + } + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = { + kafkaOffsetReader.close() + } + + override def toString(): String = s"Kafka[$kafkaOffsetReader]" + + /** + * Read initial partition offsets from the checkpoint, or decide the offsets and write them to + * the checkpoint. + */ + private def getOrCreateInitialPartitionOffsets(): PartitionOffsetMap = { + // Make sure that `KafkaConsumer.poll` is only called in StreamExecutionThread. + // Otherwise, interrupting a thread while running `KafkaConsumer.poll` may hang forever + // (KAFKA-1894). + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + + // SparkSession is required for getting Hadoop configuration for writing to checkpoints + assert(SparkSession.getActiveSession.nonEmpty) + + val metadataLog = + new KafkaSourceInitialOffsetWriter(SparkSession.getActiveSession.get, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = startingOffsets match { + case EarliestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => + kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: PartitionOffsetMap, + until: PartitionOffsetMap): PartitionOffsetMap = { + val fromNew = kafkaOffsetReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } + } + + private def getSortedExecutorList(): Array[String] = { + + def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + } + + val bm = SparkEnv.get.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } + + /** A version of [[HDFSMetadataLog]] specialized for saving the initial offsets. */ + class KafkaSourceInitialOffsetWriter(sparkSession: SparkSession, metadataPath: String) + extends HDFSMetadataLog[KafkaSourceOffset](sparkSession, metadataPath) { + + val VERSION = 1 + + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517) + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + // The log was generated by Spark 2.1.0 + KafkaSourceOffset(SerializedOffset(content)) + } + } + } +} + +/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReaderFactory( + range: KafkaOffsetRange, + preferredLoc: Option[String], + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + + override def preferredLocations(): Array[String] = preferredLoc.toArray + + override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) +} + +/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReader( + offsetRange: KafkaOffsetRange, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + private val rangeToRead = resolveRange(offsetRange) + private val converter = new KafkaRecordToUnsafeRowConverter + + private var nextOffset = rangeToRead.fromOffset + private var nextRow: UnsafeRow = _ + + override def next(): Boolean = { + if (nextOffset < rangeToRead.untilOffset) { + val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) + if (record != null) { + nextRow = converter.toUnsafeRow(record) + true + } else { + false + } + } else { + false + } + } + + override def get(): UnsafeRow = { + assert(nextRow != null) + nextOffset += 1 + nextRow + } + + override def close(): Unit = { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + + private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { + if (range.fromOffset < 0 || range.untilOffset < 0) { + // Late bind the offset range + val availableOffsetRange = consumer.getAvailableOffsetRange() + val fromOffset = if (range.fromOffset < 0) { + assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST, + s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}") + availableOffsetRange.earliest + } else { + range.fromOffset + } + val untilOffset = if (range.untilOffset < 0) { + assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST, + s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}") + availableOffsetRange.latest + } else { + range.untilOffset + } + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + } else { + range + } + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala new file mode 100644 index 0000000000000..1acdd56125741 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ +private[kafka010] class KafkaRecordToUnsafeRowConverter { + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { + bufferHolder.reset() + + if (record.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, record.key) + } + rowWriter.write(1, record.value) + rowWriter.write(2, UTF8String.fromString(record.topic)) + rowWriter.write(3, record.partition) + rowWriter.write(4, record.offset) + rowWriter.write( + 5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) + rowWriter.write(6, record.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 169a5d006fb04..1c7b3a29a861f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -306,7 +307,7 @@ private[kafka010] class KafkaSource( kafkaReader.close() } - override def toString(): String = s"KafkaSource[$kafkaReader]" + override def toString(): String = s"KafkaSourceV1[$kafkaReader]" /** * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. @@ -323,22 +324,6 @@ private[kafka010] class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you want your streaming query to fail on such cases, set the source - | option "failOnDataLoss" to "true". - """.stripMargin - - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you don't want your streaming query to fail on such cases, set the - | source option "failOnDataLoss" to "false". - """.stripMargin - private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d4fa0359c12d6..0aa64a6a9cf90 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,13 +30,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * The provider class for all Kafka readers and writers. It is designed such that it throws * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ @@ -47,6 +47,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with CreatableRelationProvider with StreamWriteSupport with ContinuousReadSupport + with MicroBatchReadSupport with Logging { import KafkaSourceProvider._ @@ -105,6 +106,52 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches + * of Kafka data in a micro-batch streaming query. + */ + override def createMicroBatchReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceOptions): KafkaMicroBatchReader = { + + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Kafka data in a continuous streaming query. + */ override def createContinuousReader( schema: Optional[StructType], metadataPath: String, @@ -408,8 +455,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + + + private val deserClassName = classOf[ByteArrayDeserializer].getName def getKafkaOffsetRangeLimit( diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin new file mode 100644 index 0000000000000..d530773f57327 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin @@ -0,0 +1,2 @@ +0v99999 +{"kafka-initial-offset-future-version":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin index ae928e724967d..8c78d9e390a0e 100644 --- a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin @@ -1 +1 @@ -2{"kafka-initial-offset-2-1-0":{"2":0,"1":0,"0":0}} \ No newline at end of file +2{"kafka-initial-offset-2-1-0":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala similarity index 85% rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 02c87643568bd..ed4ecfeafa972 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -25,6 +25,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable +import scala.io.Source import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata @@ -42,7 +43,6 @@ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -112,14 +112,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + val sources = { + query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: KafkaSource, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) + }.distinct + if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -155,7 +159,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { +abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { import testImplicits._ @@ -303,94 +307,105 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { ) } - testWithUninterruptibleThread( - "deserialization of initial offset with Spark 2.1.0") { + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { withTempDir { metadataPath => - val topic = newTopic - testUtils.createTopic(topic, partitions = 3) + val topic = "kafka-initial-offset-current" + testUtils.createTopic(topic, partitions = 1) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - source.getOffset.get // Write initial offset - - // Make sure Spark 2.1.0 will throw an exception when reading the new log - intercept[java.lang.IllegalArgumentException] { - // Simulate how Spark 2.1.0 reads the log - Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) - } + val initialOffsetFile = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0").toFile + + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + + // Test the written initial offset file has 0 byte in the beginning, so that + // Spark 2.1.0 can read the offsets (see SPARK-19517) + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + makeSureGetOffsetCalled) + + val binarySource = Source.fromFile(initialOffsetFile) + try { + assert(binarySource.next().toInt == 0) // first byte is binary 0 + } finally { + binarySource.close() } } } - testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") { + test("deserialization of initial offset written by Spark 2.1.0 (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-2-1-0" testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, Array("0", "1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("0", "10", "20"), Some(1)) + testUtils.sendMessages(topic, Array("0", "100", "200"), Some(2)) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. val from = new File( getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath - val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) Files.copy(from, to) - val source = provider.createSource( - spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) - val deserializedOffset = source.getOffset.get - val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - assert(referenceOffset == deserializedOffset) + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + // Test that the query starts from the expected initial offset (i.e. read older offsets, + // even though startingOffsets is latest). + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + AddKafkaData(Set(topic), 1000), + CheckAnswer(0, 1, 2, 10, 20, 200, 1000)) } } - testWithUninterruptibleThread("deserialization of initial offset written by future version") { + test("deserialization of initial offset written by future version") { withTempDir { metadataPath => - val futureMetadataLog = - new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, - metadataPath.getAbsolutePath) { - override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { - out.write(0) - val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v99999\n${metadata.json}") - writer.flush - } - } - - val topic = newTopic + val topic = "kafka-initial-offset-future-version" testUtils.createTopic(topic, partitions = 3) - val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - futureMetadataLog.add(0, offset) - - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - val e = intercept[java.lang.IllegalStateException] { - source.getOffset.get // Read initial offset - } + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. + val from = new File( + getClass.getResource("/kafka-source-initial-offset-future-version.bin").toURI).toPath + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) + Files.copy(from, to) - Seq( - s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", - "produced by a newer version of Spark and cannot be read by this version" - ).foreach { message => - assert(e.getMessage.contains(message)) - } + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + ExpectFailure[IllegalStateException](e => { + Seq( + s"maximum supported log version is v1, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.toString.contains(message)) + } + })) } } @@ -542,6 +557,91 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { CheckLastBatch(120 to 124: _*) ) } + + test("ensure stream-stream self-join generates only one offset in offset log") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .load() + + val values = kafka + .selectExpr("CAST(CAST(value AS STRING) AS INT) AS value", + "CAST(CAST(value AS STRING) AS INT) % 5 AS key") + + val join = values.join(values, "key") + + testStream(join)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + AddKafkaData(Set(topic), 6, 3), + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + ) + } +} + + +class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set( + "spark.sql.streaming.disabledV2MicroBatchReaders", + classOf[KafkaSourceProvider].getCanonicalName) + } + + test("V1 Source is used when disabled through SQLConf") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaSource, _) => true + }.nonEmpty + } + ) + } +} + +class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { + + test("V2 Source is used by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + }.nonEmpty + } + ) + } } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f24fd7ff74d3f..e75e1d66ebcf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1146,10 +1146,20 @@ object SQLConf { val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .internal() .doc("A comma-separated list of fully qualified data source register class names for which" + - " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") .stringConf .createWithDefault("") + val DISABLED_V2_STREAMING_MICROBATCH_READERS = + buildConf("spark.sql.streaming.disabledV2MicroBatchReaders") + .internal() + .doc( + "A comma-separated list of fully qualified data source register class names for which " + + "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " + + "V1 Sources.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1525,6 +1535,9 @@ class SQLConf extends Serializable with Logging { def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def disabledV2StreamingMicroBatchReaders: String = + getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ac73ba3417904..84655013ba957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -72,27 +72,36 @@ class MicroBatchExecution( // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, // since the existing logical plan has already used those attributes. The per-microbatch // transformation is responsible for replacing attributes with their final values. + + val disabledSources = + sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + val _logicalPlan = analyzedPlan.transform { - case streamingRelation@StreamingRelation(dataSource, _, output) => + case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) + val source = dataSourceV1.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + case s @ StreamingRelationV2( + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = source.createMicroBatchReader( + val reader = dataSourceV2.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + logInfo(s"Using MicroBatchReader [$reader] from " + + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => + case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" @@ -102,6 +111,7 @@ class MicroBatchExecution( } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(source, output)(sparkSession) }) } From d5ed2108d32e1d95b26ee7fed39e8a733e935e2c Mon Sep 17 00:00:00 2001 From: Shintaro Murakami Date: Fri, 16 Feb 2018 17:17:55 -0800 Subject: [PATCH 059/593] [SPARK-23381][CORE] Murmur3 hash generates a different value from other implementations ## What changes were proposed in this pull request? Murmur3 hash generates a different value from the original and other implementations (like Scala standard library and Guava or so) when the length of a bytes array is not multiple of 4. ## How was this patch tested? Added a unit test. **Note: When we merge this PR, please give all the credits to Shintaro Murakami.** Author: Shintaro Murakami Author: gatorsmile Author: Shintaro Murakami Closes #20630 from gatorsmile/pr-20568. --- .../spark/util/sketch/Murmur3_x86_32.java | 16 +++++++++ .../spark/unsafe/hash/Murmur3_x86_32.java | 16 +++++++++ .../unsafe/hash/Murmur3_x86_32Suite.java | 19 +++++++++++ .../spark/ml/feature/FeatureHasher.scala | 33 ++++++++++++++++++- .../spark/mllib/feature/HashingTF.scala | 2 +- .../spark/ml/feature/FeatureHasherSuite.scala | 11 ++++++- python/pyspark/ml/feature.py | 4 +-- 7 files changed, 96 insertions(+), 5 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java index a61ce4fb7241d..e83b331391e39 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 5e7ee480cafd1..d239de6083ad0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index e759cb33b3e6a..6348a73bf3895 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,6 +22,8 @@ import java.util.Random; import java.util.Set; +import scala.util.hashing.MurmurHash3$; + import org.apache.spark.unsafe.Platform; import org.junit.Assert; import org.junit.Test; @@ -51,6 +53,23 @@ public void testKnownLongInputs() { Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); } + // SPARK-23381 Check whether the hash of the byte array is the same as another implementations + @Test + public void testKnownBytesInputs() { + byte[] test = "test".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0), + Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0)); + byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0), + Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0)); + byte[] te = "te".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0), + Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0)); + byte[] tes = "tes".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0), + Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index a918dd4c075da..c78f61ac3ef71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup @@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - val hashFunc: Any => Int = OldHashingTF.murmur3Hash + val hashFunc: Any => Int = FeatureHasher.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) val catCols = if (isSet(categoricalCols)) { @@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { @Since("2.3.0") override def load(path: String): FeatureHasher = super.load(path) + + private val seed = OldHashingTF.seed + + /** + * Calculate a hash code value for the term object using + * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32). + * This is the default hash algorithm used from Spark 2.0 onwards. + * Use hashUnsafeBytes2 to match the original algorithm with the value. + * See SPARK-23381. + */ + @Since("2.3.0") + private[feature] def murmur3Hash(term: Any): Int = { + term match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case s: String => + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + + s"support type ${term.getClass.getCanonicalName} of input data.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 9abdd44a635d1..8935c8496cdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -135,7 +135,7 @@ object HashingTF { private[HashingTF] val Murmur3: String = "murmur3" - private val seed = 42 + private[spark] val seed = 42 /** * Calculate a hash code value for the term object using the native Scala implementation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 3fc3cbb62d5b5..7bc1825b69c43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class FeatureHasherSuite extends SparkFunSuite with MLlibTestSparkContext @@ -34,7 +35,7 @@ class FeatureHasherSuite extends SparkFunSuite import testImplicits._ - import HashingTFSuite.murmur3FeatureIdx + import FeatureHasherSuite.murmur3FeatureIdx implicit private val vectorEncoder = ExpressionEncoder[Vector]() @@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite testDefaultReadWrite(t) } } + +object FeatureHasherSuite { + + private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures) + } + +} diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..04b07e6a05481 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> df = spark.createDataFrame(data, cols) >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasher.setCategoricalCols(["real"]).transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) + SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) From 15ad4a7f1000c83cefbecd41e315c964caa3c39f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sat, 17 Feb 2018 10:54:14 +0800 Subject: [PATCH 060/593] [SPARK-23447][SQL] Cleanup codegen template for Literal ## What changes were proposed in this pull request? Cleaned up the codegen templates for `Literal`s, to make sure that the `ExprCode` returned from `Literal.doGenCode()` has: 1. an empty `code` field; 2. an `isNull` field of either literal `true` or `false`; 3. a `value` field that is just a simple literal/constant. Before this PR, there are a couple of paths that would return a non-trivial `code` and all of them are actually unnecessary. The `NaN` and `Infinity` constants for `double` and `float` can be accessed through constants directly available so there's no need to add a reference for them. Also took the opportunity to add a new util method for ease of creating `ExprCode` for inline-able non-null values. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #20626 from rednaxelafx/codegen-literal. --- .../expressions/codegen/CodeGenerator.scala | 6 +++ .../sql/catalyst/expressions/literals.scala | 51 ++++++++++--------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dcbb702893da..31ba29ae8d8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -58,6 +58,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} */ case class ExprCode(var code: String, var isNull: String, var value: String) +object ExprCode { + def forNonNullValue(value: String): ExprCode = { + ExprCode(code = "", isNull = "false", value = value) + } +} + /** * State used for subexpression elimination. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index cd176d941819f..c1e65e34c2ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -278,40 +278,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - // change the isNull and primitive to consts, to inline them if (value == null) { - ev.isNull = "true" - ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") + val defaultValueLiteral = ctx.defaultValue(javaType) match { + case "null" => s"(($javaType)null)" + case lit => lit + } + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) } else { - ev.isNull = "false" dataType match { case BooleanType | IntegerType | DateType => - ev.copy(code = "", value = value.toString) + ExprCode.forNonNullValue(value.toString) case FloatType => - val v = value.asInstanceOf[Float] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}f") + value.asInstanceOf[Float] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Float.NaN") + case Float.PositiveInfinity => + ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + case Float.NegativeInfinity => + ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}F") } case DoubleType => - val v = value.asInstanceOf[Double] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}D") + value.asInstanceOf[Double] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Double.NaN") + case Double.PositiveInfinity => + ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + case Double.NegativeInfinity => + ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}D") } case ByteType | ShortType => - ev.copy(code = "", value = s"($javaType)$value") + ExprCode.forNonNullValue(s"($javaType)$value") case TimestampType | LongType => - ev.copy(code = "", value = s"${value}L") + ExprCode.forNonNullValue(s"${value}L") case _ => - ev.copy(code = "", value = ctx.addReferenceObj("literal", value, - ctx.javaType(dataType))) + val constRef = ctx.addReferenceObj("literal", value, javaType) + ExprCode.forNonNullValue(constRef) } } } From 3ee3b2ae1ff8fbeb43a08becef43a9bd763b06bb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Feb 2018 00:25:36 -0800 Subject: [PATCH 061/593] [SPARK-23340][SQL] Upgrade Apache ORC to 1.4.3 ## What changes were proposed in this pull request? This PR updates Apache ORC dependencies to 1.4.3 released on February 9th. Apache ORC 1.4.2 release removes unnecessary dependencies and 1.4.3 has 5 more patches (https://s.apache.org/Fll8). Especially, the following ORC-285 is fixed at 1.4.3. ```scala scala> val df = Seq(Array.empty[Float]).toDF() scala> df.write.format("orc").save("/tmp/floatarray") scala> spark.read.orc("/tmp/floatarray") res1: org.apache.spark.sql.DataFrame = [value: array] scala> spark.read.orc("/tmp/floatarray").show() 18/02/12 22:09:10 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1) java.io.IOException: Error reading file: file:/tmp/floatarray/part-00000-9c0b461b-4df1-4c23-aac1-3e4f349ac7d6-c000.snappy.orc at org.apache.orc.impl.RecordReaderImpl.nextBatch(RecordReaderImpl.java:1191) at org.apache.orc.mapreduce.OrcMapreduceRecordReader.ensureBatch(OrcMapreduceRecordReader.java:78) ... Caused by: java.io.EOFException: Read past EOF for compressed stream Stream for column 2 kind DATA position: 0 length: 0 range: 0 offset: 0 limit: 0 ``` ## How was this patch tested? Pass the Jenkins test. Author: Dongjoon Hyun Closes #20511 from dongjoon-hyun/SPARK-23340. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 6 +----- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ .../apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala | 10 ++++++++++ 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 99031384aa22e..ed310507d14ed 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cf8d2789b7ee9..04dec04796af4 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index de949b94d676c..ac30107066389 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.1 + 1.4.3 nohive 1.6.0 9.3.20.v20170531 @@ -1740,10 +1740,6 @@ org.apache.hive hive-storage-api - - io.airlift - slice - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 6f5f2fd795f74..523f7cf77e103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -160,6 +160,15 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + Seq(Seq(Array.empty[Float]).toDF(), Seq(Array.empty[Double]).toDF()).foreach { df => + withTempPath { path => + df.write.format("orc").save(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), df) + } + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 92b2f069cacd6..597b0f56a55e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -208,4 +208,14 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "false") { + withTable("spark_23340") { + sql("CREATE TABLE spark_23340(a array, b array) STORED AS ORC") + sql("INSERT INTO spark_23340 VALUES (array(), array())") + checkAnswer(spark.table("spark_23340"), Seq(Row(Array.empty[Float], Array.empty[Double]))) + } + } + } } From f5850e78924d03448ad243cdd32b24c3fe0ea8af Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 13:33:03 +0800 Subject: [PATCH 062/593] [SPARK-23457][SQL] Register task completion listeners first in ParquetFileFormat ## What changes were proposed in this pull request? ParquetFileFormat leaks opened files in some cases. This PR prevents that by registering task completion listers first before initialization. - [spark-branch-2.3-test-sbt-hadoop-2.7](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.3-test-sbt-hadoop-2.7/205/testReport/org.apache.spark.sql/FileBasedDataSourceSuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) - [spark-master-test-sbt-hadoop-2.6](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4228/testReport/junit/org.apache.spark.sql.execution.datasources.parquet/ParquetQuerySuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) ``` Caused by: sbt.ForkMain$ForkError: java.lang.Throwable: null at org.apache.spark.DebugFilesystem$.addOpenStream(DebugFilesystem.scala:36) at org.apache.spark.DebugFilesystem.open(DebugFilesystem.scala:70) at org.apache.hadoop.fs.FileSystem.open(FileSystem.java:769) at org.apache.parquet.hadoop.ParquetFileReader.(ParquetFileReader.java:538) at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.initialize(SpecificParquetRecordReaderBase.java:149) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.initialize(VectorizedParquetRecordReader.java:133) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReaderWithPartitionValues$1.apply(ParquetFileFormat.scala:400) at ``` ## How was this patch tested? Manual. The following test case generates the same leakage. ```scala test("SPARK-23457 Register task completion listeners first in ParquetFileFormat") { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("parquet").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("parquet").save(new Path(basePath, "second").toString) val df = spark.read.parquet( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20619 from dongjoon-hyun/SPARK-23390. --- .../parquet/ParquetFileFormat.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index ba69f9a26c968..476bd02374364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -395,16 +395,21 @@ class ParquetFileFormat ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } val taskContext = Option(TaskContext.get()) - val parquetReader = if (enableVectorizedReader) { + if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) + val iter = new RecordReaderIterator(vectorizedReader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) if (returningBatch) { vectorizedReader.enableReturningBatches() } - vectorizedReader + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow @@ -414,18 +419,11 @@ class ParquetFileFormat } else { new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) } + val iter = new RecordReaderIterator(reader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) reader.initialize(split, hadoopAttemptContext) - reader - } - val iter = new RecordReaderIterator(parquetReader) - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) - - // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. - if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && - enableVectorizedReader) { - iter.asInstanceOf[Iterator[InternalRow]] - } else { val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) From 651b0277fe989119932d5ae1ef729c9768aa018d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 20 Feb 2018 13:56:38 +0800 Subject: [PATCH 063/593] [SPARK-23436][SQL] Infer partition as Date only if it can be casted to Date ## What changes were proposed in this pull request? Before the patch, Spark could infer as Date a partition value which cannot be casted to Date (this can happen when there are extra characters after a valid date, like `2018-02-15AAA`). When this happens and the input format has metadata which define the schema of the table, then `null` is returned as a value for the partition column, because the `cast` operator used in (`PartitioningAwareFileIndex.inferPartitioning`) is unable to convert the value. The PR checks in the partition inference that values can be casted to Date and Timestamp, in order to infer that datatype to them. ## How was this patch tested? added UT Author: Marco Gaido Closes #20621 from mgaido91/SPARK-23436. --- .../datasources/PartitioningUtils.scala | 40 ++++++++++++++----- .../ParquetPartitionDiscoverySuite.scala | 14 +++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 472bf82d3604d..379acb67f7c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -407,6 +407,34 @@ object PartitioningUtils { Literal(bigDecimal) } + val dateTry = Try { + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // DateType + DateTimeUtils.getThreadLocalDateFormat.parse(raw) + // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. + // This can happen since DateFormat.parse may not use the entire text of the given string: + // so if there are extra-characters after the date, it returns correctly. + // We need to check that we can cast the raw string since we later can use Cast to get + // the partition values with the right DataType (see + // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning) + val dateValue = Cast(Literal(raw), DateType).eval() + // Disallow DateType if the cast returned null + require(dateValue != null) + Literal.create(dateValue, DateType) + } + + val timestampTry = Try { + val unescapedRaw = unescapePathName(raw) + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // TimestampType + DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw) + // SPARK-23436: see comment for date + val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval() + // Disallow TimestampType if the cast returned null + require(timestampValue != null) + Literal.create(timestampValue, TimestampType) + } + if (typeInference) { // First tries integral types Try(Literal.create(Integer.parseInt(raw), IntegerType)) @@ -415,16 +443,8 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try( - Literal.create( - DateTimeUtils.getThreadLocalTimestampFormat(timeZone) - .parse(unescapePathName(raw)).getTime * 1000L, - TimestampType))) - .orElse(Try( - Literal.create( - DateTimeUtils.millisToDays( - DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), - DateType))) + .orElse(timestampTry) + .orElse(dateTry) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index d4902641e335f..edb3da904d10d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1120,4 +1120,18 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Row(3, BigDecimal("2" * 30)) :: Nil) } } + + test("SPARK-23436: invalid Dates should be inferred as String in partition inference") { + withTempPath { path => + val data = Seq(("1", "2018-01", "2018-01-01-04", "test")) + .toDF("id", "date_month", "date_hour", "data") + + data.write.partitionBy("date_month", "date_hour").parquet(path.getAbsolutePath) + val input = spark.read.parquet(path.getAbsolutePath).select("id", + "date_month", "date_hour", "data") + + assert(input.schema.sameType(input.schema)) + checkAnswer(input, data) + } + } } From aadf9535b4a11b42fd9d72f636576d2da0766199 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 20 Feb 2018 16:04:22 +0800 Subject: [PATCH 064/593] [SPARK-23203][SQL] DataSourceV2: Use immutable logical plans. ## What changes were proposed in this pull request? SPARK-23203: DataSourceV2 should use immutable catalyst trees instead of wrapping a mutable DataSourceV2Reader. This commit updates DataSourceV2Relation and consolidates much of the DataSourceV2 API requirements for the read path in it. Instead of wrapping a reader that changes, the relation lazily produces a reader from its configuration. This commit also updates the predicate and projection push-down. Instead of the implementation from SPARK-22197, this reuses the rule matching from the Hive and DataSource read paths (using `PhysicalOperation`) and copies most of the implementation of `SparkPlanner.pruneFilterProject`, with updates for DataSourceV2. By reusing the implementation from other read paths, this should have fewer regressions from other read paths and is less code to maintain. The new push-down rules also supports the following edge cases: * The output of DataSourceV2Relation should be what is returned by the reader, in case the reader can only partially satisfy the requested schema projection * The requested projection passed to the DataSourceV2Reader should include filter columns * The push-down rule may be run more than once if filters are not pushed through projections ## How was this patch tested? Existing push-down and read tests. Author: Ryan Blue Closes #20387 from rdblue/SPARK-22386-push-down-immutable-trees. --- .../kafka010/KafkaContinuousSourceSuite.scala | 19 +- .../sql/kafka010/KafkaContinuousTest.scala | 4 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 41 +--- .../datasources/v2/DataSourceV2Relation.scala | 212 +++++++++++++++++- .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../v2/PushDownOperatorsToDataSource.scala | 159 ++++--------- .../continuous/ContinuousExecution.scala | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 6 +- 10 files changed, 269 insertions(+), 187 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..f679e9bfc0450 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..48ac3fc1e8f9d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index ed4ecfeafa972..89c9ef4cc73b5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} @@ -119,7 +119,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..4274f120a375a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -189,39 +189,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. - val reader = (ds, userSpecifiedSchema) match { - case (ds: ReadSupportWithSchema, Some(schema)) => - ds.createReader(schema, options) - - case (ds: ReadSupport, None) => - ds.createReader(options) - - case (ds: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $ds.") - - case (ds: ReadSupport, Some(schema)) => - val reader = ds.createReader(options) - if (reader.readSchema() != schema) { - throw new AnalysisException(s"$ds does not allow user-specified schemas.") - } - reader - - case _ => null // fall back to v1 - } + val ds = cls.newInstance().asInstanceOf[DataSourceV2] + if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( + ds, extraOptions.toMap ++ sessionOptions, + userSpecifiedSchema = userSpecifiedSchema)) - if (reader == null) { - loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + loadV1Source(paths: _*) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..a98dd4866f82a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,17 +17,80 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + options: Map[String, String], + projection: Seq[AttributeReference], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + + import DataSourceV2Relation._ + + override def simpleString: String = { + s"DataSourceV2Relation(source=${source.name}, " + + s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + + s"filters=[${pushedFilters.mkString(", ")}], options=$options)" + } + + override lazy val schema: StructType = reader.readSchema() + + override lazy val output: Seq[AttributeReference] = { + // use the projection attributes to avoid assigning new ids. fields that are not projected + // will be assigned new ids, which is okay because they are not projected. + val attrMap = projection.map(a => a.name -> a).toMap + schema.map(f => attrMap.getOrElse(f.name, + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) + } + + private lazy val v2Options: DataSourceOptions = makeV2Options(options) + + lazy val ( + reader: DataSourceReader, + unsupportedFilters: Seq[Expression], + pushedFilters: Seq[Expression]) = { + val newReader = userSpecifiedSchema match { + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + + DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + val (remainingFilters, pushedFilters) = filters match { + case Some(filterSeq) => + DataSourceV2Relation.pushFilters(newReader, filterSeq) + case _ => + (Nil, Nil) + } + + (newReader, remainingFilters, pushedFilters) + } + + override def doCanonicalize(): LogicalPlan = { + val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + + // override output with canonicalized output to avoid attempting to configure a reader + val canonicalOutput: Seq[AttributeReference] = this.output + .map(a => QueryPlan.normalizeExprId(a, projection)) + + new DataSourceV2Relation(c.source, c.options, c.projection) { + override lazy val output: Seq[AttributeReference] = canonicalOutput + } + } override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => @@ -37,7 +100,9 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(output = output.map(_.newInstance())) + // projection is used to maintain id assignment. + // if projection is not set, use output so the copy is not equal to the original + copy(projection = projection.map(_.newInstance())) } } @@ -45,14 +110,137 @@ case class DataSourceV2Relation( * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical * to the non-streaming relation. */ -class StreamingDataSourceV2Relation( +case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + reader: DataSourceReader) + extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { override def isStreaming: Boolean = true + + override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } } object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + private implicit class SourceHelpers(source: DataSourceV2) { + def asReadSupport: ReadSupport = { + source match { + case support: ReadSupport => + support + case _: ReadSupportWithSchema => + // this method is only called if there is no user-supplied schema. if there is no + // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. + throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def asReadSupportWithSchema: ReadSupportWithSchema = { + source match { + case support: ReadSupportWithSchema => + support + case _: ReadSupport => + throw new AnalysisException( + s"Data source does not support user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def name: String = { + source match { + case registered: DataSourceRegister => + registered.shortName() + case _ => + source.getClass.getSimpleName + } + } + } + + private def makeV2Options(options: Map[String, String]): DataSourceOptions = { + new DataSourceOptions(options.asJava) + } + + private def schema( + source: DataSourceV2, + v2Options: DataSourceOptions, + userSchema: Option[StructType]): StructType = { + val reader = userSchema match { + // TODO: remove this case because it is confusing for users + case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => + val reader = source.asReadSupport.createReader(v2Options) + if (reader.readSchema() != s) { + throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") + } + reader + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + reader.readSchema() + } + + def create( + source: DataSourceV2, + options: Map[String, String], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { + val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, options, projection, filters, + // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must + // be equal to the reader's schema. the schema method enforces this. because the user schema + // and the reader's schema are identical, drop the user schema. + if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + } + + private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + reader match { + case projectionSupport: SupportsPushDownRequiredColumns => + projectionSupport.pruneColumns(struct) + case _ => + } + } + + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case catalystFilterSupport: SupportsPushDownCatalystFilters => + ( + catalystFilterSupport.pushCatalystFilters(filters.toArray), + catalystFilterSupport.pushedCatalystFilters() + ) + + case filterSupport: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet + val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => + unhandledFilters.contains(f) + } + + (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + + case _ => (filters, Nil) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..c4e7644683c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case relation: DataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + + case relation: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..f23d228567241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,130 +17,55 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} -import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.v2.reader._ -/** - * Pushes down various operators to the underlying data source for better performance. Operators are - * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you - * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the - * data source should execute FILTER before LIMIT. And required columns are calculated at the end, - * because when more operators are pushed down, we may need less columns at Spark side. - */ -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may - // appear in many places for column pruning. - // TODO: Ideally column pruning should be implemented via a plan property that is propagated - // top-down, then we can simplify the logic here and only collect target operators. - val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(condition).partition(_.deterministic) - - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(candidates.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => candidates - } - - val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) - val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) - if (withFilter.output == fields) { - withFilter - } else { - Project(fields, withFilter) - } - } - - // TODO: add more push down rules. - - val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(columnPruned) - } - - // TODO: nested fields pruning - private def pushDownRequiredColumns( - plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { - plan match { - case p @ Project(projectList, child) => - val required = projectList.flatMap(_.references) - p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - - case f @ Filter(condition, child) => - val required = requiredByParent ++ condition.references - f.copy(child = pushDownRequiredColumns(child, required)) - - case relation: DataSourceV2Relation => relation.reader match { - case reader: SupportsPushDownRequiredColumns => - // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now - // it's possible that the mutable reader being updated by someone else, and we need to - // always call `reader.pruneColumns` here to correct it. - // assert(relation.output.toStructType == reader.readSchema(), - // "Schema of data source reader does not match the relation plan.") - - val requiredColumns = relation.output.filter(requiredByParent.contains) - reader.pruneColumns(requiredColumns.toStructType) - - val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - val newOutput = reader.readSchema().map(_.name).map(nameToAttr) - relation.copy(output = newOutput) - - case _ => relation +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { + override def apply( + plan: LogicalPlan): LogicalPlan = plan transformUp { + // PhysicalOperation guarantees that filters are deterministic; no need to check + case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => + // merge the filters + val filters = relation.filters match { + case Some(existing) => + existing ++ newFilters + case _ => + newFilters } - // TODO: there may be more operators that can be used to calculate the required columns. We - // can add more and more in the future. - case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) - } - } - - /** - * Finds a Filter node(with an optional Project child) above data source relation. - */ - object FilterAndProject { - // returns the project list, the filter condition and the data source relation. - def unapply(plan: LogicalPlan) - : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + val projectAttrs = project.map(_.toAttribute) + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(projectAttrs) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + projectAttrs + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } - case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + val newRelation = relation.copy( + projection = projection.asInstanceOf[Seq[AttributeReference]], + filters = Some(filters)) - case Filter(condition, Project(fields, r: DataSourceV2Relation)) - if fields.forall(_.deterministic) => - val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) - val substituted = condition.transform { - case a: Attribute => attributeMap.getOrElse(a, a) - } - Some((fields, substituted, r)) + // Add a Filter for any filters that could not be pushed + val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) + val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - case _ => None - } + // Add a Project to ensure the output matches the required projection + if (newRelation.output != projectAttrs) { + Project(project, filtered) + } else { + filtered + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..2c1d6c509d21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a1c87fb15542c..1157a350461d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -146,7 +146,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("A schema needs to be specified")) + assert(e.message.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..159dd0ecb5902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case StreamingDataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) From 862fa697d829cdddf0f25e5613c91b040f9d9652 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 20 Feb 2018 20:26:26 +0900 Subject: [PATCH 065/593] [SPARK-23240][PYTHON] Better error message when extraneous data in pyspark.daemon's stdout ## What changes were proposed in this pull request? Print more helpful message when daemon module's stdout is empty or contains a bad port number. ## How was this patch tested? Manually recreated the environmental issues that caused the mysterious exceptions at one site. Tested that the expected messages are logged. Also, ran all scala unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bruce Robbins Closes #20424 from bersprockets/SPARK-23240_prop2. --- .../api/python/PythonWorkerFactory.scala | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 30976ac752a8a..2340580b54f67 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, EOFException, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.nio.charset.StandardCharsets import java.util.Arrays @@ -182,7 +182,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) + val command = Arrays.asList(pythonExec, "-m", daemonModule) + val pb = new ProcessBuilder(command) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -191,7 +192,29 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() + try { + daemonPort = in.readInt() + } catch { + case _: EOFException => + throw new SparkException(s"No port number in $daemonModule's stdout") + } + + // test that the returned port number is within a valid range. + // note: this does not cover the case where the port number + // is arbitrary data but is also coincidentally within range + if (daemonPort < 1 || daemonPort > 0xffff) { + val exceptionMessage = f""" + |Bad data in $daemonModule's standard output. Invalid port number: + | $daemonPort (0x$daemonPort%08x) + |Python command to execute the daemon was: + | ${command.asScala.mkString(" ")} + |Check that you don't have any unexpected modules or libraries in + |your PYTHONPATH: + | $pythonPath + |Also, check if you have a sitecustomize.py module in your python path, + |or in your python installation, that is printing to standard output""" + throw new SparkException(exceptionMessage.stripMargin) + } // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) From 189f56f3dcdad4d997248c01aa5490617f018bd0 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 20 Feb 2018 07:51:30 -0600 Subject: [PATCH 066/593] [SPARK-23383][BUILD][MINOR] Make a distribution should exit with usage while detecting wrong options ## What changes were proposed in this pull request? ```shell ./dev/make-distribution.sh --name ne-1.0.0-SNAPSHOT xyz --tgz -Phadoop-2.7 +++ dirname ./dev/make-distribution.sh ++ cd ./dev/.. ++ pwd + SPARK_HOME=/Users/Kent/Documents/spark + DISTDIR=/Users/Kent/Documents/spark/dist + MAKE_TGZ=false + MAKE_PIP=false + MAKE_R=false + NAME=none + MVN=/Users/Kent/Documents/spark/build/mvn + (( 5 )) + case $1 in + NAME=ne-1.0.0-SNAPSHOT + shift + shift + (( 3 )) + case $1 in + break + '[' -z /Users/Kent/.jenv/candidates/java/current ']' + '[' -z /Users/Kent/.jenv/candidates/java/current ']' ++ command -v git + '[' /usr/local/bin/git ']' ++ git rev-parse --short HEAD + GITREV=98ea6a7 + '[' '!' -z 98ea6a7 ']' + GITREVSTRING=' (git revision 98ea6a7)' + unset GITREV ++ command -v /Users/Kent/Documents/spark/build/mvn + '[' '!' /Users/Kent/Documents/spark/build/mvn ']' ++ /Users/Kent/Documents/spark/build/mvn help:evaluate -Dexpression=project.version xyz --tgz -Phadoop-2.7 ++ grep -v INFO ++ tail -n 1 + VERSION=' -X,--debug Produce execution debug output' ``` It is better to declare the mistakes and exit with usage than `break` ## How was this patch tested? manually cc srowen Author: Kent Yao Closes #20571 from yaooqinn/SPARK-23383. --- dev/make-distribution.sh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 8b02446b2f15f..84233c64caa9c 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -72,9 +72,17 @@ while (( "$#" )); do --help) exit_with_usage ;; - *) + --*) + echo "Error: $1 is not supported" + exit_with_usage + ;; + -*) break ;; + *) + echo "Error: $1 is not supported" + exit_with_usage + ;; esac shift done From 83c008762af444eef73d835eb6f506ecf5aebc17 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 09:14:56 -0800 Subject: [PATCH 067/593] [SPARK-23456][SPARK-21783] Turn on `native` ORC impl and PPD by default ## What changes were proposed in this pull request? Apache Spark 2.3 introduced `native` ORC supports with vectorization and many fixes. However, it's shipped as a not-default option. This PR enables `native` ORC implementation and predicate-pushdown by default for Apache Spark 2.4. We will improve and stabilize ORC data source before Apache Spark 2.4. And, eventually, Apache Spark will drop old Hive-based ORC code. ## How was this patch tested? Pass the Jenkins with existing tests. Author: Dongjoon Hyun Closes #20634 from dongjoon-hyun/SPARK-23456. --- docs/sql-programming-guide.md | 6 +++++- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 91e43678481d6..c37c338a134f3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1018,7 +1018,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also spark.sql.orc.impl hive - The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1. + The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. spark.sql.orc.enableVectorizedReader @@ -1797,6 +1797,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see # Migration Guide +## Upgrading From Spark SQL 2.3 to 2.4 + + - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e75e1d66ebcf8..ce3f94618edeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default.") + "1.2.1. It is 'hive' by default prior to Spark 2.4.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("hive") + .createWithDefault("native") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + From 3e48f3b9ee7645e4218ad3ff7559e578d4bd9667 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 16:02:44 -0800 Subject: [PATCH 068/593] [SPARK-23434][SQL] Spark should not warn `metadata directory` for a HDFS file path ## What changes were proposed in this pull request? In a kerberized cluster, when Spark reads a file path (e.g. `people.json`), it warns with a wrong warning message during looking up `people.json/_spark_metadata`. The root cause of this situation is the difference between `LocalFileSystem` and `DistributedFileSystem`. `LocalFileSystem.exists()` returns `false`, but `DistributedFileSystem.exists` raises `org.apache.hadoop.security.AccessControlException`. ```scala scala> spark.version res0: String = 2.4.0-SNAPSHOT scala> spark.read.json("file:///usr/hdp/current/spark-client/examples/src/main/resources/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ scala> spark.read.json("hdfs:///tmp/people.json") 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. ``` After this PR, ```scala scala> spark.read.json("hdfs:///tmp/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20616 from dongjoon-hyun/SPARK-23434. --- .../spark/sql/execution/streaming/FileStreamSink.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e98..87a17cebdc10c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -42,9 +42,11 @@ object FileStreamSink extends Logging { try { val hdfsPath = new Path(singlePath) val fs = hdfsPath.getFileSystem(hadoopConf) - val metadataPath = new Path(hdfsPath, metadataDir) - val res = fs.exists(metadataPath) - res + if (fs.isDirectory(hdfsPath)) { + fs.exists(new Path(hdfsPath, metadataDir)) + } else { + false + } } catch { case NonFatal(e) => logWarning(s"Error while looking for metadata directory.") From 2ba77ed9e51922303e3c3533e368b95788bd7de5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 17:54:06 -0800 Subject: [PATCH 069/593] [SPARK-23470][UI] Use first attempt of last stage to define job description. This is much faster than finding out what the last attempt is, and the data should be the same. There's room for improvement in this page (like only loading data for the jobs being shown, instead of loading all available jobs and sorting them), but this should bring performance on par with the 2.2 version. Author: Marcelo Vanzin Closes #20644 from vanzin/SPARK-23470. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index a9265d4dbcdfb..ac83de10f9237 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1048,7 +1048,7 @@ private[ui] object ApiHelper { } def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { - val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)) (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } From 6d398c05cbad69aa9093429e04ae44c73b81cd5a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 18:06:21 -0800 Subject: [PATCH 070/593] [SPARK-23468][CORE] Stringify auth secret before storing it in credentials. The secret is used as a string in many parts of the code, so it has to be turned into a hex string to avoid issues such as the random byte sequence not containing a valid UTF8 sequence. Author: Marcelo Vanzin Closes #20643 from vanzin/SPARK-23468. --- core/src/main/scala/org/apache/spark/SecurityManager.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 4c1dbe3ffb4ad..5b15a1c57779d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -541,7 +541,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretBytes) + val secretStr = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From 601d653bff9160db8477f86d961e609fc2190237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 20 Feb 2018 18:16:10 -0800 Subject: [PATCH 071/593] [SPARK-23454][SS][DOCS] Added trigger information to the Structured Streaming programming guide ## What changes were proposed in this pull request? - Added clear information about triggers - Made the semantics guarantees of watermarks more clear for streaming aggregations and stream-stream joins. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tathagata Das Closes #20631 from tdas/SPARK-23454. --- .../structured-streaming-programming-guide.md | 214 +++++++++++++++++- 1 file changed, 207 insertions(+), 7 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 48d6d0b542cc0..9a83f157452ad 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -904,7 +904,7 @@ windowedCounts <- count(
    -### Handling Late Data and Watermarking +#### Handling Late Data and Watermarking Now consider what happens if one of the events arrives late to the application. For example, say, a word generated at 12:04 (i.e. event time) could be received by the application at 12:11. The application should use the time 12:04 instead of 12:11 @@ -925,7 +925,9 @@ specifying the event time column and the threshold on how late the data is expec event time. For a specific window starting at time `T`, the engine will maintain state and allow late data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, -but data later than the threshold will be dropped. Let's understand this with an example. We can +but data later than the threshold will start getting dropped +(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below.
    @@ -1031,7 +1033,9 @@ then drops intermediate state of a window < watermark, and appends the final counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` is appended to the Result Table only after the watermark is updated to `12:11`. -**Conditions for watermarking to clean aggregation state** +##### Conditions for watermarking to clean aggregation state +{:.no_toc} + It is important to note that the following conditions must be satisfied for the watermarking to clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*. @@ -1051,6 +1055,16 @@ from the aggregation column. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append output mode. +##### Semantic Guarantees of Aggregation with Watermarking +{:.no_toc} + +- A watermark delay (set with `withWatermark`) of "2 hours" guarantees that the engine will never +drop any data that is less than 2 hours delayed. In other words, any data less than 2 hours behind +(in terms of event-time) the latest data processed till then is guaranteed to be aggregated. + +- However, the guarantee is strict only in one direction. Data delayed by more than 2 hours is +not guaranteed to be dropped; it may or may not get aggregated. More delayed is the data, less +likely is the engine going to process it. ### Join Operations Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame @@ -1062,7 +1076,7 @@ Dataset/DataFrame will be the exactly the same as if it was with a static Datase containing the same data in the stream. -#### Stream-static joins +#### Stream-static Joins Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example. @@ -1269,6 +1283,12 @@ joined <- join(
+###### Semantic Guarantees of Stream-stream Inner Joins with Watermarking +{:.no_toc} +This is similar to the [guarantees provided by watermarking on aggregations](#semantic-guarantees-of-aggregation-with-watermarking). +A watermark delay of "2 hours" guarantees that the engine will never drop any data that is less than + 2 hours delayed. But data delayed by more than 2 hours may or may not get processed. + ##### Outer Joins with Watermarking While the watermark + event-time constraints is optional for inner joins, for left and right outer joins they must be specified. This is because for generating the NULL results in outer join, the @@ -1347,7 +1367,14 @@ joined <- join(
-There are a few points to note regarding outer joins. +###### Semantic Guarantees of Stream-stream Outer Joins with Watermarking +{:.no_toc} +Outer joins have the same guarantees as [inner joins](#semantic-guarantees-of-stream-stream-inner-joins-with-watermarking) +regarding watermark delays and whether data will be dropped or not. + +###### Caveats +{:.no_toc} +There are a few important characteristics to note regarding how the outer results are generated. - *The outer NULL results will be generated with a delay that depends on the specified watermark delay and the time range condition.* This is because the engine has to wait for that long to ensure @@ -1962,7 +1989,7 @@ head(sql("select * from aggregates")) -#### Using Foreach +##### Using Foreach The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. @@ -1979,6 +2006,172 @@ which has methods that get called whenever there is a sequence of rows generated - Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks. +#### Triggers +The trigger settings of a streaming query defines the timing of streaming data processing, whether +the query is going to executed as micro-batch query with a fixed batch interval or as a continuous processing query. +Here are the different kinds of triggers that are supported. + + + + + + + + + + + + + + + + + + + + + + +
Trigger TypeDescription
unspecified (default) + If no trigger setting is explicitly specified, then by default, the query will be + executed in micro-batch mode, where micro-batches will be generated as soon as + the previous micro-batch has completed processing. +
Fixed interval micro-batches + The query will be executed with micro-batches mode, where micro-batches will be kicked off + at the user-specified intervals. +
    +
  • If the previous micro-batch completes within the interval, then the engine will wait until + the interval is over before kicking off the next micro-batch.
  • + +
  • If the previous micro-batch takes longer than the interval to complete (i.e. if an + interval boundary is missed), then the next micro-batch will start as soon as the + previous one completes (i.e., it will not wait for the next interval boundary).
  • + +
  • If no new data is available, then no micro-batch will be kicked off.
  • +
+
One-time micro-batch + The query will execute *only one* micro-batch to process all the available data and then + stop on its own. This is useful in scenarios you want to periodically spin up a cluster, + process everything that is available since the last period, and then shutdown the + cluster. In some case, this may lead to significant cost savings. +
Continuous with fixed checkpoint interval
(experimental)
+ The query will be executed in the new low-latency, continuous processing mode. Read more + about this in the Continuous Processing section below. +
+ +Here are a few code examples. + +
+
+ +{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start() + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start() + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start() + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start() + +{% endhighlight %} + + +
+
+ +{% highlight java %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start(); + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start(); + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start(); + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start(); + +{% endhighlight %} + +
+
+ +{% highlight python %} + +# Default trigger (runs micro-batch as soon as it can) +df.writeStream \ + .format("console") \ + .start() + +# ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream \ + .format("console") \ + .trigger(processingTime='2 seconds') \ + .start() + +# One-time trigger +df.writeStream \ + .format("console") \ + .trigger(once=True) \ + .start() + +# Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(continuous='1 second') + .start() + +{% endhighlight %} +
+
+ +{% highlight r %} +# Default trigger (runs micro-batch as soon as it can) +write.stream(df, "console") + +# ProcessingTime trigger with two-seconds micro-batch interval +write.stream(df, "console", trigger.processingTime = "2 seconds") + +# One-time trigger +write.stream(df, "console", trigger.once = TRUE) + +# Continuous trigger is not yet supported +{% endhighlight %} +
+
+ + ## Managing Streaming Queries The `StreamingQuery` object created when a query is started can be used to monitor and manage the query. @@ -2516,7 +2709,10 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat -# Continuous Processing [Experimental] +# Continuous Processing +## [Experimental] +{:.no_toc} + **Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, @@ -2589,6 +2785,8 @@ spark \ A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. ## Supported Queries +{:.no_toc} + As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. - *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). @@ -2606,6 +2804,8 @@ As of Spark 2.3, only the following type of queries are supported in the continu See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. ## Caveats +{:.no_toc} + - Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. - Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. - There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. From 95e25ed1a8b56937345eff637c0032aea85a503d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 21 Feb 2018 11:26:06 +0800 Subject: [PATCH 072/593] [SPARK-23424][SQL] Add codegenStageId in comment ## What changes were proposed in this pull request? This PR always adds `codegenStageId` in comment of the generated class. This is a replication of #20419 for post-Spark 2.3. Closes #20419 ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; ... ``` ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20612 from kiszk/SPARK-23424. --- .../expressions/codegen/CodeGenerator.scala | 21 ++++++++++++++++--- .../sql/execution/WholeStageCodegenExec.scala | 4 +++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 31ba29ae8d8ce..60a6f50472504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1232,14 +1232,29 @@ class CodegenContext { /** * Register a comment and return the corresponding place holder + * + * @param placeholderId an optionally specified identifier for the comment's placeholder. + * The caller should make sure this identifier is unique within the + * compilation unit. If this argument is not specified, a fresh identifier + * will be automatically created and used as the placeholder. + * @param force whether to force registering the comments */ - def registerComment(text: => String): String = { + def registerComment( + text: => String, + placeholderId: String = "", + force: Boolean = false): String = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this // flat, see SPARK-15680. - if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { - val name = freshName("c") + if (force || + SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = if (placeholderId != "") { + assert(!placeHolderToComments.contains(placeholderId)) + placeholderId + } else { + freshName("c") + } val comment = if (text.contains("\n") || text.contains("\r")) { text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0e525b1e22eb9..deb0a044c2fb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -540,7 +540,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) ${ctx.registerComment( s"""Codegend pipeline for stage (id=$codegenStageId) - |${this.treeString.trim}""".stripMargin)} + |${this.treeString.trim}""".stripMargin, + "wsc_codegenPipeline")} + ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)} final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; From c8c4441dfdfeda22f8d92e25aee1b6a6269752f9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 21 Feb 2018 15:10:08 +0800 Subject: [PATCH 073/593] [SPARK-23418][SQL] Fail DataSourceV2 reads when user schema is passed, but not supported. ## What changes were proposed in this pull request? DataSourceV2 initially allowed user-supplied schemas when a source doesn't implement `ReadSupportWithSchema`, as long as the schema was identical to the source's schema. This is confusing behavior because changes to an underlying table can cause a previously working job to fail with an exception that user-supplied schemas are not allowed. This reverts commit adcb25a0624, which was added to #20387 so that it could be removed in a separate JIRA issue and PR. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #20603 from rdblue/SPARK-23418-revert-adcb25a0624. --- .../datasources/v2/DataSourceV2Relation.scala | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a98dd4866f82a..cc6cb631e3f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -174,13 +174,6 @@ object DataSourceV2Relation { v2Options: DataSourceOptions, userSchema: Option[StructType]): StructType = { val reader = userSchema match { - // TODO: remove this case because it is confusing for users - case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => - val reader = source.asReadSupport.createReader(v2Options) - if (reader.readSchema() != s) { - throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") - } - reader case Some(s) => source.asReadSupportWithSchema.createReader(s, v2Options) case _ => @@ -195,11 +188,7 @@ object DataSourceV2Relation { filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, - // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must - // be equal to the reader's schema. the schema method enforces this. because the user schema - // and the reader's schema are identical, drop the user schema. - if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) } private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { From e836c27ce011ca9aef822bef6320b4a7059ec343 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Feb 2018 12:39:36 -0600 Subject: [PATCH 074/593] [SPARK-23217][ML][PYTHON] Add distanceMeasure param to ClusteringEvaluator Python API ## What changes were proposed in this pull request? The PR adds the `distanceMeasure` param to ClusteringEvaluator in the Python API. This allows the user to specify `cosine` as distance measure in addition to the default `squaredEuclidean`. ## How was this patch tested? added UT Author: Marco Gaido Closes #20627 from mgaido91/SPARK-23217_python. --- python/pyspark/ml/evaluation.py | 28 +++++++++++++++++++++++----- python/pyspark/ml/tests.py | 16 ++++++++++++++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 0cbce9b40048f..695d8ab27cc96 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -362,18 +362,21 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (silhouette)", typeConverter=TypeConverters.toString) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'squaredEuclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ __init__(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") """ super(ClusteringEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid) - self._setDefault(metricName="silhouette") + self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean") kwargs = self._input_kwargs self._set(**kwargs) @@ -394,15 +397,30 @@ def getMetricName(self): @keyword_only @since("2.3.0") def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") Sets params for clustering evaluator. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + + if __name__ == "__main__": import doctest import tempfile diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6d6737241e06e..116885969345c 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -51,7 +51,7 @@ from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ +from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel @@ -541,6 +541,15 @@ def test_java_params(self): self.assertEqual(evaluator._java_obj.getMetricName(), "r2") self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + class FeatureTests(SparkSessionTestCase): @@ -1961,11 +1970,14 @@ def test_java_params(self): import pyspark.ml.feature import pyspark.ml.classification import pyspark.ml.clustering + import pyspark.ml.evaluation import pyspark.ml.pipeline import pyspark.ml.recommendation import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): if not name.endswith('Model') and issubclass(cls, JavaParams)\ From 3fd0ccb13fea44727d970479af1682ef00592147 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Feb 2018 14:56:13 -0800 Subject: [PATCH 075/593] [SPARK-23484][SS] Fix possible race condition in KafkaContinuousReader ## What changes were proposed in this pull request? var `KafkaContinuousReader.knownPartitions` should be threadsafe as it is accessed from multiple threads - the query thread at the time of reader factory creation, and the epoch tracking thread at the time of `needsReconfiguration`. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20655 from tdas/SPARK-23484. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 97a0f66e1880d..ecd1170321f3f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -66,7 +66,7 @@ class KafkaContinuousReader( // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ + @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ override def readSchema: StructType = KafkaOffsetReader.kafkaSchema From 744d5af652ee8cece361cbca31e5201134e0fb42 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 15:37:28 -0800 Subject: [PATCH 076/593] [SPARK-23481][WEBUI] lastStageAttempt should fail when a stage doesn't exist ## What changes were proposed in this pull request? The issue here is `AppStatusStore.lastStageAttempt` will return the next available stage in the store when a stage doesn't exist. This PR adds `last(stageId)` to ensure it returns a correct `StageData` ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #20654 from zsxwing/SPARK-23481. --- .../apache/spark/status/AppStatusStore.scala | 6 +++- .../spark/status/AppStatusListenerSuite.scala | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index efc28538a33db..688f25a9fdea1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -95,7 +95,11 @@ private[spark] class AppStatusStore( } def lastStageAttempt(stageId: Int): v1.StageData = { - val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + val it = store.view(classOf[StageDataWrapper]) + .index("stageId") + .reverse() + .first(stageId) + .last(stageId) .closeableIterator() try { if (it.hasNext()) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 749502709b5c8..673d191b5a4db 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1121,6 +1121,39 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("lastStageAttempt should fail when the stage doesn't exist") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1) + val listener = new AppStatusListener(store, testConf, true) + val appStore = new AppStatusStore(store) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Make stage 3 complete before stage 2 so that stage 3 will be evicted + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + stage3.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage3)) + + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + + assert(appStore.asOption(appStore.lastStageAttempt(1)) === None) + assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2)) + assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) From 45cf714ee6d4eead2fe00794a0d754fa6d33d4a6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 19:43:11 -0800 Subject: [PATCH 077/593] [SPARK-23475][WEBUI] Skipped stages should be evicted before completed stages ## What changes were proposed in this pull request? The root cause of missing completed stages is because `cleanupStages` will never remove skipped stages. This PR changes the logic to always remove skipped stage first. This is safe since the job itself contains enough information to render skipped stages in the UI. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #20656 from zsxwing/SPARK-23475. --- .../spark/status/AppStatusListener.scala | 5 ++- .../spark/status/AppStatusListenerSuite.scala | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 79a17e26665fd..5ea161cd0d151 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -915,7 +915,10 @@ private[spark] class AppStatusListener( return } - val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + // As the completion time of a skipped stage is always -1, we will remove skipped stages first. + // This is safe since the job itself contains enough information to render skipped stages in the + // UI. + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime") val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 673d191b5a4db..1cd71955ad4d9 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1089,6 +1089,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("skipped stages should be evicted before completed stages") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Sart job 1 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start and stop stage 1 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 and stage 2 will become SKIPPED + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Submit stage 3 and verify stage 2 is evicted + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + test("eviction should respect task completion time") { val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) val listener = new AppStatusListener(store, testConf, true) From 87293c746e19d66f475d506d0adb43421f496843 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 22 Feb 2018 11:00:12 -0800 Subject: [PATCH 078/593] [SPARK-23475][UI] Show also skipped stages ## What changes were proposed in this pull request? SPARK-20648 introduced the status `SKIPPED` for the stages. On the UI, previously, skipped stages were shown as `PENDING`; after this change, they are not shown on the UI. The PR introduce a new section in order to show also `SKIPPED` stages in a proper table. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20651 from mgaido91/SPARK-23475. --- .../org/apache/spark/ui/static/webui.js | 1 + .../apache/spark/ui/jobs/AllStagesPage.scala | 27 +++++++++++++++++++ .../org/apache/spark/ui/UISeleniumSuite.scala | 17 ++++++++++++ 3 files changed, 45 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 83009df91d30a..f01c567ba58ad 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -72,6 +72,7 @@ $(function() { collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allSkippedStages','aggregated-allSkippedStages'); collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 606dc1e180e5b..38450b9126ff0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -36,6 +36,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) val pendingStages = allStages.filter(_.status == StageStatus.PENDING) + val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse @@ -51,6 +52,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val completedStagesTable = new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", parent.basePath, subPath, parent.isFairScheduler, false, false) + val skippedStagesTable = + new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) val failedStagesTable = new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", parent.basePath, subPath, parent.isFairScheduler, false, true) @@ -66,6 +70,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val shouldShowActiveStages = activeStages.nonEmpty val shouldShowPendingStages = pendingStages.nonEmpty val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = skippedStages.nonEmpty val shouldShowFailedStages = failedStages.nonEmpty val appSummary = parent.store.appSummary() @@ -102,6 +107,14 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { } } + { + if (shouldShowSkippedStages) { +
  • + Skipped Stages: + {skippedStages.size} +
  • + } + } { if (shouldShowFailedStages) {
  • @@ -172,6 +185,20 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { {completedStagesTable.toNodeSeq} } + if (shouldShowSkippedStages) { + content ++= + +

    + + Skipped Stages ({skippedStages.size}) +

    +
    ++ +
    + {skippedStagesTable.toNodeSeq} +
    + } if (shouldShowFailedStages) { content ++= + val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache() + rdd.count() + rdd.count() + + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/stages") + find(id("skipped")).get.text should be("Skipped Stages (1)") + } + val stagesJson = getJson(sc.ui.get, "stages") + stagesJson.children.size should be (4) + val stagesStatus = stagesJson.children.map(_ \ "status") + stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } From c5abb3c2d16f601d507bee3c53663d4e117eb8b5 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 22 Feb 2018 12:07:51 -0800 Subject: [PATCH 079/593] [SPARK-23476][CORE] Generate secret in local mode when authentication on ## What changes were proposed in this pull request? If spark is run with "spark.authenticate=true", then it will fail to start in local mode. This PR generates secret in local mode when authentication on. ## How was this patch tested? Modified existing unit test. Manually started spark-shell. Author: Gabor Somogyi Closes #20652 from gaborgsomogyi/SPARK-23476. --- .../org/apache/spark/SecurityManager.scala | 16 ++++-- .../apache/spark/SecurityManagerSuite.scala | 50 +++++++++++++------ docs/security.md | 2 +- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 5b15a1c57779d..2519d266879aa 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -520,19 +520,25 @@ private[spark] class SecurityManager( * * If authentication is disabled, do nothing. * - * In YARN mode, generate a new secret and store it in the current user's credentials. + * In YARN and local mode, generate a new secret and store it in the current user's credentials. * * In other modes, assert that the auth secret is set in the configuration. */ def initializeAuth(): Unit = { + import SparkMasterRegex._ + if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return } - if (sparkConf.get(SparkLauncher.SPARK_MASTER, null) != "yarn") { - require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), - s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") - return + val master = sparkConf.get(SparkLauncher.SPARK_MASTER, "") + master match { + case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => + // Secret generation allowed here + case _ => + require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") + return } val rnd = new SecureRandom() diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index cf59265dd646d..106ece7aed0a4 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -440,23 +440,41 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } - test("secret key generation in yarn mode") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(SparkLauncher.SPARK_MASTER, "yarn") - val mgr = new SecurityManager(conf) - - UserGroupInformation.createUserForTesting("authTest", Array()).doAs( - new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - mgr.initializeAuth() - val creds = UserGroupInformation.getCurrentUser().getCredentials() - val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) - assert(secret != null) - assert(new String(secret, UTF_8) === mgr.getSecretKey()) + test("secret key generation") { + Seq( + ("yarn", true), + ("local", true), + ("local[*]", true), + ("local[1, 2]", true), + ("local-cluster[2, 1, 1024]", false), + ("invalid", false) + ).foreach { case (master, shouldGenerateSecret) => + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(SparkLauncher.SPARK_MASTER, master) + val mgr = new SecurityManager(conf) + + UserGroupInformation.createUserForTesting("authTest", Array()).doAs( + new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + if (shouldGenerateSecret) { + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + } else { + intercept[IllegalArgumentException] { + mgr.initializeAuth() + } + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } + } + } } - } - ) + ) + } } } diff --git a/docs/security.md b/docs/security.md index bebc28ddbfb0e..0f384b411812a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,7 +6,7 @@ title: Security Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. +* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. ## Web UI From 049f243c59737699fee54fdc9d65cbd7c788032a Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 22 Feb 2018 21:49:25 -0800 Subject: [PATCH 080/593] [SPARK-23490][SQL] Check storage.locationUri with existing table in CreateTable ## What changes were proposed in this pull request? For CreateTable with Append mode, we should check if `storage.locationUri` is the same with existing table in `PreprocessTableCreation` In the current code, there is only a simple exception if the `storage.locationUri` is different with existing table: `org.apache.spark.sql.AnalysisException: Table or view not found:` which can be improved. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20660 from gengliangwang/locationUri. --- .../sql/execution/datasources/rules.scala | 8 +++++ .../sql/execution/command/DDLSuite.scala | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5cc21eeaeaa94..0dea767840ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -118,6 +118,14 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " + s"`${specifiedProvider.getSimpleName}`.") } + tableDesc.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw new AnalysisException( + s"The location of the existing table ${tableIdentWithDB.quotedString} is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${tableDesc.location}`.") + case _ => + } if (query.schema.length != existingTable.schema.length) { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f76bfd2fda2b9..b800e6ff5b0ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -536,6 +536,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create table - append to a non-partitioned table created with different paths") { + import testImplicits._ + withTempDir { dir1 => + withTempDir { dir2 => + withTable("path_test") { + Seq(1L -> "a").toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir1.getCanonicalPath) + .saveAsTable("path_test") + + val ex = intercept[AnalysisException] { + Seq((3L, "c")).toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir2.getCanonicalPath) + .saveAsTable("path_test") + }.getMessage + assert(ex.contains("The location of the existing table `default`.`path_test`")) + + checkAnswer( + spark.table("path_test"), Row(1L, "a") :: Nil) + } + } + } + } + test("Refresh table after changing the data source table partitioning") { import testImplicits._ From 855ce13d045569b7b16fdc7eee9c981f4ff3a545 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Feb 2018 12:40:58 -0800 Subject: [PATCH 081/593] [SPARK-23408][SS] Synchronize successive AddData actions in Streaming*JoinSuite **The best way to review this PR is to ignore whitespace/indent changes. Use this link - https://github.com/apache/spark/pull/20650/files?w=1** ## What changes were proposed in this pull request? The stream-stream join tests add data to multiple sources and expect it all to show up in the next batch. But there's a race condition; the new batch might trigger when only one of the AddData actions has been reached. Prior attempt to solve this issue by jose-torres in #20646 attempted to simultaneously synchronize on all memory sources together when consecutive AddData was found in the actions. However, this carries the risk of deadlock as well as unintended modification of stress tests (see the above PR for a detailed explanation). Instead, this PR attempts the following. - A new action called `StreamProgressBlockedActions` that allows multiple actions to be executed while the streaming query is blocked from making progress. This allows data to be added to multiple sources that are made visible simultaneously in the next batch. - An alias of `StreamProgressBlockedActions` called `MultiAddData` is explicitly used in the `Streaming*JoinSuites` to add data to two memory sources simultaneously. This should avoid unintentional modification of the stress tests (or any other test for that matter) while making sure that the flaky tests are deterministic. ## How was this patch tested? Modified test cases in `Streaming*JoinSuites` where there are consecutive `AddData` actions. Author: Tathagata Das Closes #20650 from tdas/SPARK-23408. --- .../streaming/MicroBatchExecution.scala | 10 + .../spark/sql/streaming/StreamTest.scala | 472 ++++++++++-------- .../sql/streaming/StreamingJoinSuite.scala | 54 +- 3 files changed, 284 insertions(+), 252 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84655013ba957..6bd03972c301d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -504,6 +504,16 @@ class MicroBatchExecution( } } + /** Execute a function while locking the stream from making an progress */ + private[sql] def withProgressLocked(f: => Unit): Unit = { + awaitProgressLock.lock() + try { + f + } finally { + awaitProgressLock.unlock() + } + } + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { Optional.ofNullable(scalaOption.orNull) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 159dd0ecb5902..08f722ecb10e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -102,6 +102,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be AddDataMemory(source, data) } + /** + * Adds data to multiple memory streams such that all the data will be made visible in the + * same batch. This is applicable only to MicroBatchExecution, as this coordination cannot be + * performed at the driver in ContinuousExecutions. + */ + object MultiAddData { + def apply[A] + (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = { + val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2)) + StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]")) + } + } + /** A trait that can be extended when testing a source. */ trait AddData extends StreamAction { /** @@ -217,6 +230,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]" } + /** + * Performs multiple actions while locking the stream from progressing. + * This is applicable only to MicroBatchExecution, as progress of ContinuousExecution + * cannot be controlled from the driver. + */ + case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: String = null) + extends StreamAction { + + override def toString(): String = { + if (desc != null) desc else super.toString + } + } + /** Assert that a body is true */ class Assert(condition: => Boolean, val message: String = "") extends StreamAction { def run(): Unit = { Assertions.assert(condition) } @@ -295,6 +321,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + var manualClockExpectedTime = -1L @volatile var streamThreadDeathCause: Throwable = null @@ -425,243 +454,254 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - var manualClockExpectedTime = -1L - val defaultCheckpointLocation = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - try { - startedTest.foreach { action => - logInfo(s"Processing test stream action: $action") - action match { - case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") - verify(triggerClock.isInstanceOf[SystemClock] - || triggerClock.isInstanceOf[StreamManualClock], - "Use either SystemClock or StreamManualClock to start the stream") - if (triggerClock.isInstanceOf[StreamManualClock]) { - manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + def executeAction(action: StreamAction): Unit = { + logInfo(s"Processing test stream action: $action") + action match { + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => + verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + + additionalConfs.foreach(pair => { + val value = + if (sparkSession.conf.contains(pair._1)) { + Some(sparkSession.conf.get(pair._1)) + } else None + resetConfValues(pair._1) = value + sparkSession.conf.set(pair._1, pair._2) + }) + + lastStream = currentStream + currentStream = + sparkSession + .streams + .startQuery( + None, + Some(metadataRoot), + stream, + Map(), + sink, + outputMode, + trigger = trigger, + triggerClock = triggerClock) + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + // Wait until the initialization finishes, because some tests need to use `logicalPlan` + // after starting the query. + try { + currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + assert(s.lastExecution != null) + } + case _ => } - val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + } catch { + case _: StreamingQueryException => + // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. + } - additionalConfs.foreach(pair => { - val value = - if (sparkSession.conf.contains(pair._1)) { - Some(sparkSession.conf.get(pair._1)) - } else None - resetConfValues(pair._1) = value - sparkSession.conf.set(pair._1, pair._2) - }) + case AdvanceManualClock(timeToAdd) => + verify(currentStream != null, + "can not advance manual clock when a stream is not running") + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], + s"can not advance clock of type ${currentStream.triggerClock.getClass}") + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) + + // Make sure we don't advance ManualClock too early. See SPARK-16002. + eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + } + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") + + case StopStream => + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.queryExecutionThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest( + "Timed out while stopping and waiting for microbatchthread to terminate.", e) + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { lastStream = currentStream - currentStream = - sparkSession - .streams - .startQuery( - None, - Some(metadataRoot), - stream, - Map(), - sink, - outputMode, - trigger = trigger, - triggerClock = triggerClock) - .asInstanceOf[StreamingQueryWrapper] - .streamingQuery - // Wait until the initialization finishes, because some tests need to use `logicalPlan` - // after starting the query. - try { - currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - assert(s.lastExecution != null) - } - case _ => - } - } catch { - case _: StreamingQueryException => - // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. - } + currentStream = null + } - case AdvanceManualClock(timeToAdd) => - verify(currentStream != null, - "can not advance manual clock when a stream is not running") - verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], - s"can not advance clock of type ${currentStream.triggerClock.getClass}") - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - assert(manualClockExpectedTime >= 0) - - // Make sure we don't advance ManualClock too early. See SPARK-16002. - eventually("StreamManualClock has not yet entered the waiting state") { - assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[StreamingQueryException] { + currentStream.awaitTermination() } - - clock.advance(timeToAdd) - manualClockExpectedTime += timeToAdd - verify(clock.getTimeMillis() === manualClockExpectedTime, - s"Unexpected clock time after updating: " + - s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") - - case StopStream => - verify(currentStream != null, "can not stop a stream that is not running") - try failAfter(streamingTimeout) { - currentStream.stop() - verify(!currentStream.queryExecutionThread.isAlive, - s"microbatch thread not stopped") - verify(!currentStream.isActive, - "query.isActive() is false even after stopping") - verify(currentStream.exception.isEmpty, - s"query.exception() is not empty after clean stop: " + - currentStream.exception.map(_.toString()).getOrElse("")) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest( - "Timed out while stopping and waiting for microbatchthread to terminate.", e) - case t: Throwable => - failTest("Error while stopping stream", t) - } finally { - lastStream = currentStream - currentStream = null + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.queryExecutionThread.isAlive) } + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + if (ef.isFatalError) { + // This is a fatal error, `streamThreadDeathCause` should be set to this error in + // UncaughtExceptionHandler. + verify(streamThreadDeathCause != null && + streamThreadDeathCause.getClass === ef.causeClass, + "UncaughtExceptionHandler didn't receive the correct error\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") + streamThreadDeathCause = null + } + ef.assertFailure(exception.getCause) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure", e) + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + } - case ef: ExpectFailure[_] => - verify(currentStream != null, "can not expect failure when stream is not running") - try failAfter(streamingTimeout) { - val thrownException = intercept[StreamingQueryException] { - currentStream.awaitTermination() - } - eventually("microbatch thread not stopped after termination with failure") { - assert(!currentStream.queryExecutionThread.isAlive) + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when no stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") + + case a: AddData => + try { + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } } - verify(currentStream.exception === Some(thrownException), - s"incorrect exception returned by query.exception()") - - val exception = currentStream.exception.get - verify(exception.cause.getClass === ef.causeClass, - "incorrect cause in exception returned by query.exception()\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") - if (ef.isFatalError) { - // This is a fatal error, `streamThreadDeathCause` should be set to this error in - // UncaughtExceptionHandler. - verify(streamThreadDeathCause != null && - streamThreadDeathCause.getClass === ef.causeClass, - "UncaughtExceptionHandler didn't receive the correct error\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") - streamThreadDeathCause = null + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") } - ef.assertFailure(exception.getCause) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while waiting for failure", e) - case t: Throwable => - failTest("Error while checking stream failure", t) - } finally { - lastStream = currentStream - currentStream = null } - - case a: AssertOnQuery => - verify(currentStream != null || lastStream != null, - "cannot assert when no stream has been started") - val streamToAssert = Option(currentStream).getOrElse(lastStream) - try { - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") - } catch { - case NonFatal(e) => - failTest(s"Assert on query failed: ${a.message}", e) + // Add data + val queryToUse = Option(currentStream).orElse(Option(lastStream)) + val (source, offset) = a.addData(queryToUse) + + def findSourceIndex(plan: LogicalPlan): Option[Int] = { + plan + .collect { + case StreamingExecutionRelation(s, _) => s + case StreamingDataSourceV2Relation(_, r) => r + } + .zipWithIndex + .find(_._1 == source) + .map(_._2) } - case a: Assert => - val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify({ a.run(); true }, s"Assert failed: ${a.message}") - - case a: AddData => - try { - - // If the query is running with manual clock, then wait for the stream execution - // thread to start waiting for the clock to increment. This is needed so that we - // are adding data when there is no trigger that is active. This would ensure that - // the data gets deterministically added to the next batch triggered after the manual - // clock is incremented in following AdvanceManualClock. This avoid race conditions - // between the test thread and the stream execution thread in tests using manual - // clock. - if (currentStream != null && - currentStream.triggerClock.isInstanceOf[StreamManualClock]) { - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - eventually("Error while synchronizing with manual clock before adding data") { - if (currentStream.isActive) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + // Try to find the index of the source to which data was added. Either get the index + // from the current active query or the original input logical plan. + val sourceIndex = + queryToUse.flatMap { query => + findSourceIndex(query.logicalPlan) + }.orElse { + findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) } - if (!currentStream.isActive) { - failTest("Query terminated while synchronizing with manual clock") - } - } - // Add data - val queryToUse = Option(currentStream).orElse(Option(lastStream)) - val (source, offset) = a.addData(queryToUse) - - def findSourceIndex(plan: LogicalPlan): Option[Int] = { - plan - .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r - } - .zipWithIndex - .find(_._1 == source) - .map(_._2) + }.getOrElse { + throw new IllegalArgumentException( + "Could not find index of the source to which data was added") } - // Try to find the index of the source to which data was added. Either get the index - // from the current active query or the original input logical plan. - val sourceIndex = - queryToUse.flatMap { query => - findSourceIndex(query.logicalPlan) - }.orElse { - findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } - }.getOrElse { - throw new IllegalArgumentException( - "Could not find index of the source to which data was added") - } + // Store the expected offset of added data to wait for it later + awaiting.put(sourceIndex, offset) + } catch { + case NonFatal(e) => + failTest("Error adding data", e) + } - // Store the expected offset of added data to wait for it later - awaiting.put(sourceIndex, offset) - } catch { - case NonFatal(e) => - failTest("Error adding data", e) - } + case e: ExternalAction => + e.runAction() - case e: ExternalAction => - e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { + error => failTest(error) + } - case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { - error => failTest(error) - } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } - case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = currentStream match { - case null => fetchStreamAnswer(lastStream, lastOnly) - case s => fetchStreamAnswer(s, lastOnly) - } - QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { - error => failTest(error) - } + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + try { + globalCheckFunction(sparkAnswer) + } catch { + case e: Throwable => failTest(e.toString) + } + } + pos += 1 + } - case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - try { - globalCheckFunction(sparkAnswer) - } catch { - case e: Throwable => failTest(e.toString) - } - } - pos += 1 + try { + startedTest.foreach { + case StreamProgressLockedActions(actns, _) => + // Perform actions while holding the stream from progressing + assert(currentStream != null, + s"Cannot perform stream-progress-locked actions $actns when query is not active") + assert(currentStream.isInstanceOf[MicroBatchExecution], + s"Cannot perform stream-progress-locked actions on non-microbatch queries") + currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { + actns.foreach(executeAction) + } + + case action: StreamAction => executeAction(action) } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 92087f68ad74a..11bdd13942dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -462,15 +462,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -493,15 +491,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with value <= 7 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) @@ -524,15 +520,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with value <= 4 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) @@ -555,15 +549,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -575,13 +567,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -595,13 +585,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -676,11 +664,9 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), @@ -688,22 +674,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows - AddData(leftInput, 40, 41), - AddData(rightInput, 40, 41), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - AddData(leftInput, 70), - AddData(rightInput, 71), + MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), assertNumStateRows(total = 6, updated = 2), AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), // rightValue between 300 and 1000 should generate outer join rows even though it matches left - AddData(leftInput, 101, 102, 103), - AddData(rightInput, 101, 102, 103), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), CheckLastBatch(), - AddData(leftInput, 1000), - AddData(rightInput, 1001), + MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), assertNumStateRows(total = 8, updated = 2), AddData(rightInput, 1000), From 1a198ce8f580bcf35b9cbfab403fc40f821046a1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 23 Feb 2018 16:30:32 -0800 Subject: [PATCH 082/593] [SPARK-23459][SQL] Improve the error message when unknown column is specified in partition columns ## What changes were proposed in this pull request? This PR avoids to print schema internal information when unknown column is specified in partition columns. This PR prints column names in the schema with more readable format. The following is an example. Source code ``` test("save with an unknown partition column") { withTempDir { dir => val path = dir.getCanonicalPath Seq(1L -> "a").toDF("i", "j").write .format("parquet") .partitionBy("unknownColumn") .save(path) } ``` Output without this PR ``` Partition column unknownColumn not found in schema StructType(StructField(i,LongType,false), StructField(j,StringType,true)); ``` Output with this PR ``` Partition column unknownColumn not found in schema struct; ``` ## How was this patch tested? Manually tested Author: Kazuaki Ishizaki Closes #20653 from kiszk/SPARK-23459. --- .../datasources/PartitioningUtils.scala | 3 ++- .../apache/spark/sql/sources/SaveLoadSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 379acb67f7c71..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -486,7 +486,8 @@ object PartitioningUtils { val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => schema.find(f => equality(f.name, col)).getOrElse { - throw new AnalysisException(s"Partition column $col not found in schema $schema") + val schemaCatalog = schema.catalogString + throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog") } }).asNullable } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 773d34dfaf9a8..12779b46bfe8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -126,4 +126,20 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA checkLoad(df2, "jsonTable2") } + + test("SPARK-23459: Improve error message when specified unknown column in partition columns") { + withTempDir { dir => + val path = dir.getCanonicalPath + val unknown = "unknownColumn" + val df = Seq(1L -> "a").toDF("i", "j") + val schemaCatalog = df.schema.catalogString + val e = intercept[AnalysisException] { + df.write + .format("parquet") + .partitionBy(unknown) + .save(path) + }.getMessage + assert(e.contains(s"Partition column `$unknown` not found in schema $schemaCatalog")) + } + } } From 3ca9a2c56513444d7b233088b020d2d43fa6b77a Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Sun, 25 Feb 2018 09:29:59 -0600 Subject: [PATCH 083/593] =?UTF-8?q?[SPARK-22886][ML][TESTS]=20ML=20test=20?= =?UTF-8?q?for=20structured=20streaming:=20ml.recomme=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Converting spark.ml.recommendation tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20362 from gaborgsomogyi/SPARK-22886. --- .../spark/ml/recommendation/ALSSuite.scala | 213 ++++++++++++------ .../apache/spark/ml/util/MLTestingUtils.scala | 44 ---- 2 files changed, 143 insertions(+), 114 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index addcd21d50aac..e3dfe2faf5698 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,8 +22,7 @@ import java.util.Random import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.WrappedArray +import scala.collection.mutable.{ArrayBuffer, WrappedArray} import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -35,21 +34,20 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.recommendation.ALS.Rating -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class ALSSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { +class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() @@ -413,34 +411,36 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => - (rating.toDouble, prediction.toDouble) + testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") { + case rows: Seq[Row] => + val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble)) + + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the + // weighted RMSE with the confidence scores as weights. + val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val errorSquares = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + } + val mse = errorSquares.sum / errorSquares.length + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) } - val rmse = - if (implicitPrefs) { - // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. - // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE - // with the confidence scores as weights. - val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => - val confidence = 1.0 + alpha * math.abs(rating) - val rating01 = math.max(math.min(rating, 1.0), 0.0) - val prediction01 = math.max(math.min(prediction, 1.0), 0.0) - val err = prediction01 - rating01 - (confidence, confidence * err * err) - }.reduce { case ((c0, e0), (c1, e1)) => - (c0 + c1, e0 + e1) - } - math.sqrt(weightedSumSq / totalWeight) - } else { - val mse = predictions.map { case (rating, prediction) => - val err = rating - prediction - err * err - }.mean() - math.sqrt(mse) - } - logInfo(s"Test RMSE is $rmse.") - assert(rmse < targetRMSE) MLTestingUtils.checkCopyAndUids(als, model) } @@ -586,6 +586,68 @@ class ALSSuite allModelParamSettings, checkModelData) } + private def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val df = dfs.find { + case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType + } match { + case Some((_, df)) => df + } + val expected = estimator.fit(df) + val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } + + val baseDF = dfs.find(_._1.numericType == baseType).get._2 + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + private class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Int, Double)]) + + private def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String) = { + + import testImplicits._ + + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col) + val types = + Seq(new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()) + ) + types.map { t => + val cols = Seq(col(column).cast(t.numericType)) ++ others + t -> df.select(cols: _*) + } + } + test("input type validation") { val spark = this.spark import spark.implicits._ @@ -595,12 +657,16 @@ class ALSSuite val als = new ALS().setMaxIter(1).setRank(1) Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { case (colName, sqlType) => - MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) => - ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) - } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== - act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1) + } { (ex, act, df, enc) => + val expected = ex.transform(df).selectExpr("prediction") + .first().getFloat(0) + testTransformerByGlobalCheckFunc(df, act, "prediction") { + case rows: Seq[Row] => + expected ~== rows.head.getFloat(0) absTol 1e-6 + }(enc) } } // check user/item ids falling outside of Int range @@ -628,18 +694,22 @@ class ALSSuite } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) - assert(intercept[SparkException] { - model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains(msg)) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + assert(intercept[SparkException] { + model.transform(dataFrame).first + }.getMessage.contains(msg)) + assert(intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + }.getMessage.contains(msg)) + } + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } } @@ -662,28 +732,31 @@ class ALSSuite val knownItem = data.select(max("item")).as[Int].first() val unknownItem = knownItem + 20 val test = Seq( - (unknownUser, unknownItem), - (knownUser, unknownItem), - (unknownUser, knownItem), - (knownUser, knownItem) - ).toDF("user", "item") + (unknownUser, unknownItem, true), + (knownUser, unknownItem, true), + (unknownUser, knownItem, true), + (knownUser, knownItem, false) + ).toDF("user", "item", "expectedIsNaN") val als = new ALS().setMaxIter(1).setRank(1) // default is 'nan' val defaultModel = als.fit(data) - val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() - assert(defaultPredictions.length == 4) - assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) - assert(!defaultPredictions.last.isNaN) + testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") { + case Row(expectedIsNaN: Boolean, prediction: Float) => + assert(prediction.isNaN === expectedIsNaN) + } // check 'drop' strategy should filter out rows with unknown users/items - val dropPredictions = defaultModel - .setColdStartStrategy("drop") - .transform(test) - .select("prediction").as[Float].collect() - assert(dropPredictions.length == 1) - assert(!dropPredictions.head.isNaN) - assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + val defaultPrediction = defaultModel.transform(test).select("prediction") + .as[Float].filter(!_.isNaN).first() + testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test, + defaultModel.setColdStartStrategy("drop"), "prediction") { + case rows: Seq[Row] => + val dropPredictions = rows.map(_.getFloat(0)) + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPrediction relTol 1e-14) + } } test("case insensitive cold start param value") { @@ -693,7 +766,7 @@ class ALSSuite val data = ratings.toDF val model = new ALS().fit(data) Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => - model.setColdStartStrategy(s).transform(data) + testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index aef81c8c173a0..c328d81b4bc3a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite { } } - def checkNumericTypesALS( - estimator: ALS, - spark: SparkSession, - column: String, - baseType: NumericType) - (check: (ALSModel, ALSModel) => Unit) - (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { - val dfs = genRatingsDFWithNumericCols(spark, column) - val expected = estimator.fit(dfs(baseType)) - val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) - actuals.foreach { case (_, actual) => check(expected, actual) } - actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } - - val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) - val cols = Seq(col(column).cast(StringType)) ++ others - val strDF = baseDF.select(cols: _*) - val thrown = intercept[IllegalArgumentException] { - estimator.fit(strDF) - } - assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) - } - def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } - def genRatingsDFWithNumericCols( - spark: SparkSession, - column: String): Map[NumericType, DataFrame] = { - val df = spark.createDataFrame(Seq( - (0, 10, 1.0), - (1, 20, 2.0), - (2, 30, 3.0), - (3, 40, 4.0), - (4, 50, 5.0) - )).toDF("user", "item", "rating") - - val others = df.columns.toSeq.diff(Seq(column)).map(col) - val types: Seq[NumericType] = - Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map { t => - val cols = Seq(col(column).cast(t)) ++ others - t -> df.select(cols: _*) - }.toMap - } - def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", From b308182f233b8840dfe0e6b5736d2f2746f40757 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 26 Feb 2018 08:39:44 -0800 Subject: [PATCH 084/593] [SPARK-23438][DSTREAMS] Fix DStreams data loss with WAL when driver crashes ## What changes were proposed in this pull request? There is a race condition introduced in SPARK-11141 which could cause data loss. The problem is that ReceivedBlockTracker.insertAllocatedBatch function assumes that all blocks from streamIdToUnallocatedBlockQueues allocated to the batch and clears the queue. In this PR only the allocated blocks will be removed from the queue which will prevent data loss. ## How was this patch tested? Additional unit test + manually. Author: Gabor Somogyi Closes #20620 from gaborgsomogyi/SPARK-23438. --- .../scheduler/ReceivedBlockTracker.scala | 11 +++++---- .../streaming/ReceivedBlockTrackerSuite.scala | 23 ++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 5d9a8ac0d9297..dacff69d55dd2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker( getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo } - // Insert the recovered block-to-batch allocations and clear the queue of received blocks - // (when the blocks were originally allocated to the batch, the queue must have been cleared). + // Insert the recovered block-to-batch allocations and removes them from queue of + // received blocks. def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") - streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + allocatedBlocks.streamIdToAllocatedBlocks.foreach { + case (streamId, allocatedBlocksInStream) => + getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet) + } timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } @@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { + private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { logTrace(s"Writing record: $record") try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 107c3f5dcc08d..4fa236bd39663 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult -import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _} import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } + test("recovery with write ahead logs should remove only allocated blocks from received queue") { + val manualClock = new ManualClock + val batchTime = manualClock.getTimeMillis() + + val tracker1 = createTracker(clock = manualClock) + tracker1.isWriteAheadLogEnabled should be (true) + + val allocatedBlockInfos = generateBlockInfos() + val unallocatedBlockInfos = generateBlockInfos() + val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos + receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) } + val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos)) + tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + tracker1.stop() + + val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) + tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks + tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates From 185f5bc7dd52cebe8fac9393ecb2bd0968bc5867 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Mon, 26 Feb 2018 10:28:45 -0800 Subject: [PATCH 085/593] [SPARK-23449][K8S] Preserve extraJavaOptions ordering For some JVM options, like `-XX:+UnlockExperimentalVMOptions` ordering is necessary. ## What changes were proposed in this pull request? Keep original `extraJavaOptions` ordering, when passing them through environment variables inside the Docker container. ## How was this patch tested? Ran base branch a couple of times and checked startup command in logs. Ordering differed every time. Added sorting, ordering was consistent to what user had in `extraJavaOptions`. Author: Andrew Korzhuev Closes #20628 from andrusha/patch-2. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index b9090dc2852a5..3d67b0a702dd4 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -41,7 +41,7 @@ fi shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" -env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" From 7ec83658fbc88505dfc2d8a6f76e90db747f1292 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 26 Feb 2018 11:28:44 -0800 Subject: [PATCH 086/593] [SPARK-23491][SS] Remove explicit job cancellation from ContinuousExecution reconfiguring ## What changes were proposed in this pull request? Remove queryExecutionThread.interrupt() from ContinuousExecution. As detailed in the JIRA, interrupting the thread is only relevant in the microbatch case; for continuous processing the query execution can quickly clean itself up without. ## How was this patch tested? existing tests Author: Jose Torres Closes #20622 from jose-torres/SPARK-23441. --- .../streaming/continuous/ContinuousExecution.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2c1d6c509d21b..daebd1dd010ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -236,9 +236,7 @@ class ContinuousExecution( startTrigger() if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { - stopSources() if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() } false @@ -266,12 +264,20 @@ class ContinuousExecution( SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } + } catch { + case t: Throwable + if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING => + logInfo(s"Query $id ignoring exception from reconfiguring: $t") + // interrupted by reconfiguration - swallow exception so we can restart the query } finally { epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() epochUpdateThread.join() + + stopSources() + sparkSession.sparkContext.cancelJobGroup(runId.toString) } } From 8077bb04f350fd35df83ef896135c0672dc3f7b0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Mon, 26 Feb 2018 23:37:31 -0800 Subject: [PATCH 087/593] [SPARK-23445] ColumnStat refactoring ## What changes were proposed in this pull request? Refactor ColumnStat to be more flexible. * Split `ColumnStat` and `CatalogColumnStat` just like `CatalogStatistics` is split from `Statistics`. This detaches how the statistics are stored from how they are processed in the query plan. `CatalogColumnStat` keeps `min` and `max` as `String`, making it not depend on dataType information. * For `CatalogColumnStat`, parse column names from property names in the metastore (`KEY_VERSION` property), not from metastore schema. This means that `CatalogColumnStat`s can be created for columns even if the schema itself is not stored in the metastore. * Make all fields optional. `min`, `max` and `histogram` for columns were optional already. Having them all optional is more consistent, and gives flexibility to e.g. drop some of the fields through transformations if they are difficult / impossible to calculate. The added flexibility will make it possible to have alternative implementations for stats, and separates stats collection from stats and estimation processing in plans. ## How was this patch tested? Refactored existing tests to work with refactored `ColumnStat` and `CatalogColumnStat`. New tests added in `StatisticsSuite` checking that backwards / forwards compatibility is not broken. Author: Juliusz Sompolski Closes #20624 from juliuszsompolski/SPARK-23445. --- .../sql/catalyst/catalog/interface.scala | 146 ++++++++- .../optimizer/StarSchemaDetection.scala | 6 +- .../catalyst/plans/logical/Statistics.scala | 256 ++-------------- .../statsEstimation/AggregateEstimation.scala | 6 +- .../statsEstimation/EstimationUtils.scala | 20 +- .../statsEstimation/FilterEstimation.scala | 98 +++--- .../statsEstimation/JoinEstimation.scala | 55 ++-- .../catalyst/optimizer/JoinReorderSuite.scala | 25 +- .../StarJoinCostBasedReorderSuite.scala | 96 ++---- .../optimizer/StarJoinReorderSuite.scala | 77 ++--- .../AggregateEstimationSuite.scala | 24 +- .../BasicStatsEstimationSuite.scala | 12 +- .../FilterEstimationSuite.scala | 279 +++++++++--------- .../statsEstimation/JoinEstimationSuite.scala | 138 +++++---- .../ProjectEstimationSuite.scala | 70 +++-- .../StatsEstimationTestBase.scala | 10 +- .../command/AnalyzeColumnCommand.scala | 138 ++++++++- .../spark/sql/execution/command/tables.scala | 9 +- .../spark/sql/StatisticsCollectionSuite.scala | 9 +- .../sql/StatisticsCollectionTestBase.scala | 168 +++++++++-- .../spark/sql/hive/HiveExternalCatalog.scala | 63 ++-- .../spark/sql/hive/StatisticsSuite.scala | 162 +++------- 22 files changed, 995 insertions(+), 872 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 95b6fbb0cd61a..f3e67dc4e975c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -21,7 +21,9 @@ import java.net.URI import java.util.Date import scala.collection.mutable +import scala.util.control.NonFatal +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** @@ -361,7 +363,7 @@ object CatalogTable { case class CatalogStatistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - colStats: Map[String, ColumnStat] = Map.empty) { + colStats: Map[String, CatalogColumnStat] = Map.empty) { /** * Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based @@ -369,7 +371,8 @@ case class CatalogStatistics( */ def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = { if (cboEnabled && rowCount.isDefined) { - val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _))) + val attrStats = AttributeMap(planOutput + .flatMap(a => colStats.get(a.name).map(a -> _.toPlanStat(a.name, a.dataType)))) // Estimate size as number of rows * row size. val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats) Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats) @@ -387,6 +390,143 @@ case class CatalogStatistics( } } +/** + * This class of statistics for a column is used in [[CatalogTable]] to interact with metastore. + */ +case class CatalogColumnStat( + distinctCount: Option[BigInt] = None, + min: Option[String] = None, + max: Option[String] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, + histogram: Option[Histogram] = None) { + + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the column and name of the field (e.g. "colName.distinctCount"), + * and the value is the string representation for the value. + * min/max values are stored as Strings. They can be deserialized using + * [[CatalogColumnStat.fromExternalString]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * Any of the fields that are null (None) won't appear in the map. + */ + def toMap(colName: String): Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", "1") + distinctCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString) + } + nullCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString) + } + avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) } + maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) } + min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) } + max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) } + histogram.foreach { h => + map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h)) + } + map.toMap + } + + /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */ + def toPlanStat( + colName: String, + dataType: DataType): ColumnStat = + ColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) +} + +object CatalogColumnStat extends Logging { + + // List of string keys used to serialize CatalogColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" + + /** + * Converts from string representation of data type to the corresponding Catalyst data type. + */ + def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + + /** + * Creates a [[CatalogColumnStat]] object from the given map. + * This is used to deserialize column stats from some external storage. + * The serialization side is defined in [[CatalogColumnStat.toMap]]. + */ + def fromMap( + table: String, + colName: String, + map: Map[String, String]): Option[CatalogColumnStat] = { + + try { + Some(CatalogColumnStat( + distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)), + min = map.get(s"${colName}.${KEY_MIN_VALUE}"), + max = map.get(s"${colName}.${KEY_MAX_VALUE}"), + nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)), + avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong), + maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong), + histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize) + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e) + None + } + } +} + case class CatalogTableType private(name: String) object CatalogTableType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 1f20b7661489e..2aa762e2595ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -187,11 +187,11 @@ object StarSchemaDetection extends PredicateHelper { stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { + val colStats = stats.attributeStats.get(col).get + if (!colStats.hasCountStats || colStats.nullCount.get > 0) { false } else { - val distinctCount = colStats.get.distinctCount + val distinctCount = colStats.distinctCount.get val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) // ndvMaxErr adjusted based on TPCDS 1TB data results relDiff <= conf.ndvMaxError * 2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 96b199d7f20b0..b3a48860aa63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -27,6 +27,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} @@ -79,11 +80,10 @@ case class Statistics( /** * Statistics collected for a column. * - * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * 1. The JVM data type stored in min/max is the internal data type for the corresponding * Catalyst data type. For example, the internal type of DateType is Int, and that the internal * type of TimestampType is Long. - * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -95,240 +95,32 @@ case class Statistics( * @param histogram histogram of the values */ case class ColumnStat( - distinctCount: BigInt, - min: Option[Any], - max: Option[Any], - nullCount: BigInt, - avgLen: Long, - maxLen: Long, + distinctCount: Option[BigInt] = None, + min: Option[Any] = None, + max: Option[Any] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, histogram: Option[Histogram] = None) { - // We currently don't store min/max for binary/string type. This can change in the future and - // then we need to remove this require. - require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) - require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - - /** - * Returns a map from string to string that can be used to serialize the column stats. - * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. min/max values are converted to the external data type. For - * example, for DateType we store java.sql.Date, and for TimestampType we store - * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. - * - * As part of the protocol, the returned map always contains a key called "version". - * In the case min/max values are null (None), they won't appear in the map. - */ - def toMap(colName: String, dataType: DataType): Map[String, String] = { - val map = new scala.collection.mutable.HashMap[String, String] - map.put(ColumnStat.KEY_VERSION, "1") - map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) - map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) - map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) - map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } - histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } - map.toMap - } - - /** - * Converts the given value from Catalyst data type to string representation of external - * data type. - */ - private def toExternalString(v: Any, colName: String, dataType: DataType): String = { - val externalValue = dataType match { - case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) - case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) - case BooleanType | _: IntegralType | FloatType | DoubleType => v - case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal - // This version of Spark does not use min/max for binary/string types so we ignore it. - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $colName of data type: $dataType.") - } - externalValue.toString - } - -} + // Are distinctCount and nullCount statistics defined? + val hasCountStats = distinctCount.isDefined && nullCount.isDefined + // Are min and max statistics defined? + val hasMinMaxStats = min.isDefined && max.isDefined -object ColumnStat extends Logging { - - // List of string keys used to serialize ColumnStat - val KEY_VERSION = "version" - private val KEY_DISTINCT_COUNT = "distinctCount" - private val KEY_MIN_VALUE = "min" - private val KEY_MAX_VALUE = "max" - private val KEY_NULL_COUNT = "nullCount" - private val KEY_AVG_LEN = "avgLen" - private val KEY_MAX_LEN = "maxLen" - private val KEY_HISTOGRAM = "histogram" - - /** Returns true iff the we support gathering column statistics on column of the given type. */ - def supportsType(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case BooleanType => true - case DateType => true - case TimestampType => true - case BinaryType | StringType => true - case _ => false - } - - /** Returns true iff the we support gathering histogram on column of the given type. */ - def supportsHistogram(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case DateType => true - case TimestampType => true - case _ => false - } - - /** - * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats - * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. - */ - def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { - try { - Some(ColumnStat( - distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), - // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - nullCount = BigInt(map(KEY_NULL_COUNT).toLong), - avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, - histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) - )) - } catch { - case NonFatal(e) => - logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) - None - } - } - - /** - * Converts from string representation of external data type to the corresponding Catalyst data - * type. - */ - private def fromExternalString(s: String, name: String, dataType: DataType): Any = { - dataType match { - case BooleanType => s.toBoolean - case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) - case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) - case ByteType => s.toByte - case ShortType => s.toShort - case IntegerType => s.toInt - case LongType => s.toLong - case FloatType => s.toFloat - case DoubleType => s.toDouble - case _: DecimalType => Decimal(s) - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $name of data type: $dataType.") - } - } - - /** - * Constructs an expression to compute column statistics for a given column. - * - * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, - * distinctCountsForIntervals: Array[Long] - * - * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and - * as a result should stay in sync with it. - */ - def statExprs( - col: Attribute, - conf: SQLConf, - colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { - def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => - expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } - }) - val one = Literal(1, LongType) - - // the approximate ndv (num distinct value) should never be larger than the number of rows - val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) - val numNulls = Subtract(Count(one), numNonNulls) - val defaultSize = Literal(col.dataType.defaultSize, LongType) - val nullArray = Literal(null, ArrayType(LongType)) - - def fixedLenTypeStruct: CreateNamedStruct = { - val genHistogram = - ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) - val intervalNdvsExpr = if (genHistogram) { - ApproxCountDistinctForIntervals(col, - Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) - } else { - nullArray - } - // For fixed width types, avg size should be the same as max size. - struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, - defaultSize, defaultSize, intervalNdvsExpr) - } - - col.dataType match { - case _: IntegralType => fixedLenTypeStruct - case _: DecimalType => fixedLenTypeStruct - case DoubleType | FloatType => fixedLenTypeStruct - case BooleanType => fixedLenTypeStruct - case DateType => fixedLenTypeStruct - case TimestampType => fixedLenTypeStruct - case BinaryType | StringType => - // For string and binary type, we don't compute min, max or histogram - val nullLit = Literal(null, col.dataType) - struct( - ndv, nullLit, nullLit, numNulls, - // Set avg/max size to default size if all the values are null or there is no value. - Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), - nullArray) - case _ => - throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") - } - } - - /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ - def rowToColumnStat( - row: InternalRow, - attr: Attribute, - rowCount: Long, - percentiles: Option[ArrayData]): ColumnStat = { - // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. - val cs = ColumnStat( - distinctCount = BigInt(row.getLong(0)), - // for string/binary min/max, get should return null - min = Option(row.get(1, attr.dataType)), - max = Option(row.get(2, attr.dataType)), - nullCount = BigInt(row.getLong(3)), - avgLen = row.getLong(4), - maxLen = row.getLong(5) - ) - if (row.isNullAt(6)) { - cs - } else { - val ndvs = row.getArray(6).toLongArray() - assert(percentiles.get.numElements() == ndvs.length + 1) - val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) - // Construct equi-height histogram - val bins = ndvs.zipWithIndex.map { case (ndv, i) => - HistogramBin(endpoints(i), endpoints(i + 1), ndv) - } - val nonNullRows = rowCount - cs.nullCount - val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) - cs.copy(histogram = Some(histogram)) - } - } + // Are avgLen and maxLen statistics defined? + val hasLenStats = avgLen.isDefined && maxLen.isDefined + def toCatalogColumnStat(colName: String, dataType: DataType): CatalogColumnStat = + CatalogColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index c41fac4015ec0..111c594a53e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -32,13 +32,15 @@ object AggregateEstimation { val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => - e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) + e.isInstanceOf[Attribute] && + childStats.attributeStats.get(e.asInstanceOf[Attribute]).exists(_.hasCountStats) } if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( - (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) + (res, expr) => res * + childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get) outputRows = if (agg.groupingExpressions.isEmpty) { // If there's no group-by columns, the output is a single row containing values of aggregate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index d793f77413d18..0f147f0ffb135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -38,9 +39,18 @@ object EstimationUtils { } } + /** Check if each attribute has column stat containing distinct and null counts + * in the corresponding statistic. */ + def columnStatsWithCountsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { + statsAndAttr.forall { case (stats, attr) => + stats.attributeStats.get(attr).map(_.hasCountStats).getOrElse(false) + } + } + + /** Statistics for a Column containing only NULLs. */ def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, - avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(rowCount), + avgLen = Some(dataType.defaultSize), maxLen = Some(dataType.defaultSize)) } /** @@ -70,13 +80,13 @@ object EstimationUtils { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. val sizePerRow = 8 + attributes.map { attr => - if (attrStats.contains(attr)) { + if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => // UTF8String: base + offset + numBytes - attrStats(attr).avgLen + 8 + 4 + attrStats(attr).avgLen.get + 8 + 4 case _ => - attrStats(attr).avgLen + attrStats(attr).avgLen.get } } else { attr.dataType.defaultSize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4cc32de2d32d7..0538c9d88584b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, isNull: Boolean, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) { logDebug("[CBO] No statistics for " + attr) return None } @@ -234,14 +234,14 @@ case class FilterEstimation(plan: Filter) extends Logging { val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble + (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble } if (update) { val newStats = if (isNull) { - colStat.copy(distinctCount = 0, min = None, max = None) + colStat.copy(distinctCount = Some(0), min = None, max = None) } else { - colStat.copy(nullCount = 0) + colStat.copy(nullCount = Some(0)) } colStatsMap.update(attr, newStats) } @@ -322,17 +322,21 @@ case class FilterEstimation(plan: Filter) extends Logging { // value. val newStats = attr.dataType match { case StringType | BinaryType => - colStat.copy(distinctCount = 1, nullCount = 0) + colStat.copy(distinctCount = Some(1), nullCount = Some(0)) case _ => - colStat.copy(distinctCount = 1, min = Some(literal.value), - max = Some(literal.value), nullCount = 0) + colStat.copy(distinctCount = Some(1), min = Some(literal.value), + max = Some(literal.value), nullCount = Some(0)) } colStatsMap.update(attr, newStats) } if (colStat.histogram.isEmpty) { - // returns 1/ndv if there is no histogram - Some(1.0 / colStat.distinctCount.toDouble) + if (!colStat.distinctCount.isEmpty) { + // returns 1/ndv if there is no histogram + Some(1.0 / colStat.distinctCount.get.toDouble) + } else { + None + } } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -378,13 +382,13 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, hSet: Set[Any], update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.hasDistinctCount(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount + val ndv = colStat.distinctCount.get val dataType = attr.dataType var newNdv = ndv @@ -407,8 +411,8 @@ case class FilterEstimation(plan: Filter) extends Logging { // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), - max = Some(newMax), nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin), + max = Some(newMax), nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -416,7 +420,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case StringType | BinaryType => newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) colStatsMap.update(attr, newStats) } } @@ -443,12 +447,17 @@ case class FilterEstimation(plan: Filter) extends Logging { literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] val max = statsInterval.max val min = statsInterval.min - val ndv = colStat.distinctCount.toDouble + val ndv = colStat.distinctCount.get.toDouble // determine the overlapping degree between predicate interval and column's interval val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) @@ -520,8 +529,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = colStat.copy(distinctCount = ceil(ndv * percent), - min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)), + min = newMin, max = newMax, nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -637,11 +646,11 @@ case class FilterEstimation(plan: Filter) extends Logging { attrRight: Attribute, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attrLeft)) { + if (!colStatsMap.hasCountStats(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) return None } - if (!colStatsMap.contains(attrRight)) { + if (!colStatsMap.hasCountStats(attrRight)) { logDebug("[CBO] No statistics for " + attrRight) return None } @@ -668,7 +677,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val minRight = statsIntervalRight.min // determine the overlapping degree between predicate interval and column's interval - val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val allNotNull = (colStatLeft.nullCount.get == 0) && (colStatRight.nullCount.get == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right // - no overlap: @@ -707,14 +716,14 @@ case class FilterEstimation(plan: Filter) extends Logging { case _: EqualTo => ((maxLeft < minRight) || (maxRight < minLeft), (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) case _: EqualNullSafe => // For null-safe equality, we use a very restrictive condition to evaluate its overlap. // If null values exists, we set it to partial overlap. (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) } @@ -731,9 +740,9 @@ case class FilterEstimation(plan: Filter) extends Logging { if (update) { // Need to adjust new min/max after the filter condition is applied - val ndvLeft = BigDecimal(colStatLeft.distinctCount) + val ndvLeft = BigDecimal(colStatLeft.distinctCount.get) val newNdvLeft = ceil(ndvLeft * percent) - val ndvRight = BigDecimal(colStatRight.distinctCount) + val ndvRight = BigDecimal(colStatRight.distinctCount.get) val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max @@ -817,10 +826,10 @@ case class FilterEstimation(plan: Filter) extends Logging { } } - val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + val newStatsLeft = colStatLeft.copy(distinctCount = Some(newNdvLeft), min = newMinLeft, max = newMaxLeft) colStatsMap(attrLeft) = newStatsLeft - val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + val newStatsRight = colStatRight.copy(distinctCount = Some(newNdvRight), min = newMinRight, max = newMaxRight) colStatsMap(attrRight) = newStatsRight } @@ -849,17 +858,35 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a) /** - * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in - * originalMap, because updatedMap has the latest (updated) column stats. + * Gets an Option of column stat for the given attribute. + * Prefer the column stat in updatedMap than that in originalMap, + * because updatedMap has the latest (updated) column stats. */ - def apply(a: Attribute): ColumnStat = { + def get(a: Attribute): Option[ColumnStat] = { if (updatedMap.contains(a.exprId)) { - updatedMap(a.exprId)._2 + updatedMap.get(a.exprId).map(_._2) } else { - originalMap(a) + originalMap.get(a) } } + def hasCountStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + def hasDistinctCount(a: Attribute): Boolean = + get(a).map(_.distinctCount.isDefined).getOrElse(false) + + def hasMinMaxStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + /** + * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in + * originalMap, because updatedMap has the latest (updated) column stats. + */ + def apply(a: Attribute): ColumnStat = { + get(a).get + } + /** Updates column stats in updatedMap. */ def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats) @@ -871,11 +898,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { : AttributeMap[ColumnStat] = { val newColumnStats = originalMap.map { case (attr, oriColStat) => val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) - val newNdv = if (colStat.distinctCount > 1) { + val newNdv = if (colStat.distinctCount.isEmpty) { + // No NDV in the original stats. + None + } else if (colStat.distinctCount.get > 1) { // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows // decreases; otherwise keep it unchanged. - EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, - newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get)) } else { // no need to scale down since it is already down to 1 (for skewed distribution case) colStat.distinctCount diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f0294a4246703..2543e38a92c0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -85,7 +85,8 @@ case class JoinEstimation(join: Join) extends Logging { // 3. Update statistics based on the output of join val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) - val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val attributesWithStat = join.output.filter(a => + inputAttrStats.get(a).map(_.hasCountStats).getOrElse(false)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { @@ -106,10 +107,10 @@ case class JoinEstimation(join: Join) extends Logging { case FullOuter => fromLeft.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + rightRows))) } ++ fromRight.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + leftRows))) } case _ => assert(joinType == Inner || joinType == Cross) @@ -219,19 +220,27 @@ case class JoinEstimation(join: Join) extends Logging { private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, - newMin: Option[Any], - newMax: Option[Any]): (BigInt, ColumnStat) = { + min: Option[Any], + max: Option[Any]): (BigInt, ColumnStat) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + val maxNdv = leftKeyStat.distinctCount.get.max(rightKeyStat.distinctCount.get) // Compute cardinality by the basic formula. val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) // Get the intersected column stat. - val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + val newNdv = Some(leftKeyStat.distinctCount.get.min(rightKeyStat.distinctCount.get)) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(newNdv, min, max, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -267,9 +276,17 @@ case class JoinEstimation(join: Join) extends Logging { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(Some(ceil(totalNdv)), newMin, newMax, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -292,10 +309,14 @@ case class JoinEstimation(join: Join) extends Logging { } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount - val newNdv = if (join.left.outputSet.contains(a)) { - updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv) + val newNdv = if (oldNdv.isDefined) { + Some(if (join.left.outputSet.contains(a)) { + updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get) + } else { + updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get) + }) } else { - updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv) + None } val newColStat = oldColStat.copy(distinctCount = newNdv) // TODO: support nullCount updates for specific outer joins @@ -313,7 +334,7 @@ case class JoinEstimation(join: Join) extends Logging { // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to // support it in the future by using `nullCount` in column stats. case (lk: AttributeReference, rk: AttributeReference) - if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) + if columnStatsWithCountsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 2fb587d50a4cb..565b0a10154a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -62,24 +62,15 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } - /** Set up tables and columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("t1.k-1-2") -> rangeColumnStat(2, 0), + attr("t1.v-1-10") -> rangeColumnStat(10, 0), + attr("t2.k-1-5") -> rangeColumnStat(5, 0), + attr("t3.v-1-100") -> rangeColumnStat(100, 0), + attr("t4.k-1-2") -> rangeColumnStat(2, 0), + attr("t4.v-1-10") -> rangeColumnStat(10, 0), + attr("t5.k-1-5") -> rangeColumnStat(5, 0), + attr("t5.v-1-5") -> rangeColumnStat(5, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index ada6e2a43ea0f..d4d23ad69b2c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -68,88 +68,56 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 (fact table) - attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(100, 0), + attr("f1_fk2") -> rangeColumnStat(100, 0), + attr("f1_fk3") -> rangeColumnStat(100, 0), + attr("f1_c1") -> rangeColumnStat(100, 0), + attr("f1_c2") -> rangeColumnStat(100, 0), // D1 (dimension) - attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk") -> rangeColumnStat(100, 0), + attr("d1_c2") -> rangeColumnStat(50, 0), + attr("d1_c3") -> rangeColumnStat(50, 0), // D2 (dimension) - attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_pk") -> rangeColumnStat(20, 0), + attr("d2_c2") -> rangeColumnStat(10, 0), + attr("d2_c3") -> rangeColumnStat(10, 0), // D3 (dimension) - attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk") -> rangeColumnStat(10, 0), + attr("d3_c2") -> rangeColumnStat(5, 0), + attr("d3_c3") -> rangeColumnStat(5, 0), // T1 (regular table i.e. outside star) - attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c1") -> rangeColumnStat(20, 1), + attr("t1_c2") -> rangeColumnStat(10, 1), + attr("t1_c3") -> rangeColumnStat(10, 1), // T2 (regular table) - attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c1") -> rangeColumnStat(5, 1), + attr("t2_c2") -> rangeColumnStat(5, 1), + attr("t2_c3") -> rangeColumnStat(5, 1), // T3 (regular table) - attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c1") -> rangeColumnStat(5, 1), + attr("t3_c2") -> rangeColumnStat(5, 1), + attr("t3_c3") -> rangeColumnStat(5, 1), // T4 (regular table) - attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c1") -> rangeColumnStat(5, 1), + attr("t4_c2") -> rangeColumnStat(5, 1), + attr("t4_c3") -> rangeColumnStat(5, 1), // T5 (regular table) - attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c1") -> rangeColumnStat(5, 1), + attr("t5_c2") -> rangeColumnStat(5, 1), + attr("t5_c3") -> rangeColumnStat(5, 1), // T6 (regular table) - attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4) + attr("t6_c1") -> rangeColumnStat(5, 1), + attr("t6_c2") -> rangeColumnStat(5, 1), + attr("t6_c3") -> rangeColumnStat(5, 1) )) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 777c5637201ed..4e0883e91e84a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -70,59 +70,40 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // Tables' cardinality: f1 > d3 > d1 > d2 > s3 private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 - attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(3, 0), + attr("f1_fk2") -> rangeColumnStat(3, 0), + attr("f1_fk3") -> rangeColumnStat(4, 0), + attr("f1_c4") -> rangeColumnStat(4, 0), // D1 - attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk1") -> rangeColumnStat(4, 0), + attr("d1_c2") -> rangeColumnStat(3, 0), + attr("d1_c3") -> rangeColumnStat(4, 0), + attr("d1_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D2 - attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = Some(3), min = Some("1"), max = Some("3"), + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)), + attr("d2_pk1") -> rangeColumnStat(3, 0), + attr("d2_c3") -> rangeColumnStat(3, 0), + attr("d2_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D3 - attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_fk1") -> rangeColumnStat(3, 0), + attr("d3_c2") -> rangeColumnStat(3, 0), + attr("d3_pk1") -> rangeColumnStat(5, 0), + attr("d3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // S3 - attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_pk1") -> rangeColumnStat(2, 0), + attr("s3_c2") -> rangeColumnStat(1, 0), + attr("s3_c3") -> rangeColumnStat(1, 0), + attr("s3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // F11 - attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("f11_fk1") -> rangeColumnStat(3, 0), + attr("f11_fk2") -> rangeColumnStat(3, 0), + attr("f11_fk3") -> rangeColumnStat(4, 0), + attr("f11_c4") -> rangeColumnStat(4, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 23f95a6cc2ac2..8213d568fe85e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -29,16 +29,16 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key11") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key12") -> ColumnStat(distinctCount = Some(4), min = Some(10), max = Some(40), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key21") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -63,8 +63,8 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { tableRowCount = 6, groupByColumns = Seq("key21", "key22"), // Row count = product of ndv - expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2 - .distinctCount) + expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get * + nameToColInfo("key22")._2.distinctCount.get) } test("empty group-by column") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 7d532ff343178..953094cb0dd52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.IntegerType class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") - val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val plan = StatsTestPlan( outputList = Seq(attribute), @@ -116,13 +116,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2b1fe987a7960..43440d51dede6 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() - val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() - val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1) + val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)) // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() - val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatDate = ColumnStat(distinctCount = Some(10), + min = Some(dMin), max = Some(dMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = Decimal("0.200000000000000000") val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDecimal = ColumnStat(distinctCount = Some(4), + min = Some(decMin), max = Some(decMax), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() - val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() - val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2) + val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)) // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint < cint2 val attrInt2 = AttributeReference("cint2", IntegerType)() - val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint = cint3 without overlap at all. val attrInt3 = AttributeReference("cint3", IntegerType)() - val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint4 has values in the range from 1 to 10 // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 // This column is created to test complete overlap val attrInt4 = AttributeReference("cint4", IntegerType)() - val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. // Note that cintHgm has an even distribution with histogram information built. @@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) - val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt)) // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. // Note that cintSkewHgm has a skewed distribution with histogram information built. @@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) - val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew)) val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, @@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 3)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } @@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } @@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 2)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))), expectedRowCount = 2) } @@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 6)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))), expectedRowCount = 6) } @@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 5)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -342,47 +344,47 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 9), - attrString -> colStatString.copy(distinctCount = 9)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)), + attrString -> colStatString.copy(distinctCount = Some(9))), expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 7)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } @@ -391,18 +393,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(1), + min = Some(d20170102), max = Some(d20170102), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { + val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170101), max = Some(d20170103), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -414,8 +419,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170103), max = Some(d20170105), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -424,42 +430,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(1), + min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 1) } test("cdecimal < 0.60 ") { + val dec_0_20 = Decimal("0.200000000000000000") val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(3), + min = Some(dec_0_20), max = Some(dec_0_60), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 1) } test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 10) } @@ -468,8 +477,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. // The predicate is: column IN (1, 2, 3, 4, 5). - val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildColStatInt = ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val cornerChildStatsTestplan = StatsTestPlan( outputList = Seq(attrInt), rowCount = 2L, @@ -477,16 +487,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) validateEstimatedStats( Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) } // This is a limitation test. We should remove it after the limitation is removed. test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { val attrIntLargerRange = AttributeReference("c1", IntegerType)() - val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 10, avgLen = 4, maxLen = 4) + val colStatIntLargerRange = ColumnStat(distinctCount = Some(20), + min = Some(1), max = Some(20), + nullCount = Some(10), avgLen = Some(4), maxLen = Some(4)) val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) val largerTable = StatsTestPlan( outputList = Seq(attrIntLargerRange), @@ -508,10 +519,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -519,10 +530,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -530,10 +541,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -541,10 +552,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // complete overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -552,10 +563,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -571,10 +582,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // all table records qualify. validateEstimatedStats( Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -592,11 +603,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), Seq( - attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - attrString -> colStatString.copy(distinctCount = 5)), + attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrString -> colStatString.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -606,15 +617,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cintHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 1) } @@ -629,8 +640,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } @@ -645,16 +656,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } test("cintHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -669,8 +680,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 5) } @@ -679,8 +690,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -688,7 +699,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -698,15 +709,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))), expectedRowCount = 9) } test("cintSkewHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 4) } @@ -721,8 +732,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintSkewHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -738,16 +749,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } test("cintSkewHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -764,8 +775,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 3) } @@ -774,8 +785,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 8) } @@ -783,7 +794,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(2))), expectedRowCount = 3) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 26139d85d25fb..12c0a7be21292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def estimateByHistogram( leftHistogram: Histogram, rightHistogram: Histogram, - expectedMin: Double, - expectedMax: Double, + expectedMin: Any, + expectedMax: Any, expectedNdv: Long, expectedRows: Long): Unit = { val col1 = attr("key1") @@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRows), attributeStats = AttributeMap(Seq( col1 -> c1.stats.attributeStats(col1).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)), col2 -> c2.stats.attributeStats(col2).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)))) ) // Join order should not affect estimation result. @@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def generateJoinChild( col: Attribute, histogram: Histogram, - expectedMin: Double, - expectedMax: Double): LogicalPlan = { - val colStat = inferColumnStat(histogram) + expectedMin: Any, + expectedMax: Any): LogicalPlan = { + val colStat = inferColumnStat(histogram, expectedMin, expectedMax) StatsTestPlan( outputList = Seq(col), rowCount = (histogram.height * histogram.bins.length).toLong, @@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } /** Column statistics should be consistent with histograms in tests. */ - private def inferColumnStat(histogram: Histogram): ColumnStat = { + private def inferColumnStat( + histogram: Histogram, + expectedMin: Any, + expectedMax: Any): ColumnStat = { + var ndv = 0L for (i <- histogram.bins.indices) { val bin = histogram.bins(i) @@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { ndv += bin.ndv } } - ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), - max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + ColumnStat(distinctCount = Some(ndv), + min = Some(expectedMin), max = Some(expectedMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(histogram)) } @@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(5 + 3), attributeStats = AttributeMap( // Update null count in column stats. - Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), - nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), - nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), - nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5))))) assert(join.stats == expectedStats) } @@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) // Update column stats for equi-join keys (key-1-5 and key-1-2). - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it // unchanged (key-2-4). - val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) + val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5)) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. - val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) - val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1), + min = Some(false), max = Some(false), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toByte), max = Some(1.toByte), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toShort), max = Some(1.toShort), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1), max = Some(1), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1L), max = Some(1L), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0), max = Some(1.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0f), max = Some(1.0f), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(1), min = Some(dec), max = Some(dec), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1), + min = Some(date), max = Some(date), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1), + min = Some(timestamp), max = Some(timestamp), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) ) } @@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("join with null column") { val (nullColumn, nullColStat) = (attr("cnull"), - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4))) val nullTable = StatsTestPlan( outputList = Seq(nullColumn), rowCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index cda54fa9d64f4..dcb37017329fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.types._ class ProjectEstimationSuite extends StatsEstimationTestBase { test("project with alias") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1), - max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4)) - val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10), - max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1), + max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) + val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10), + max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1, ar2), @@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("project on empty table") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1), rowCount = 0, @@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2), + min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(6.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(7.0), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(2), min = Some(dec1), max = Some(dec2), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2), + min = Some(d1), max = Some(d2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2), + min = Some(t1), max = Some(t2), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) )) val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 31dea2e3e7f1d..9dceca59f5b87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite { def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen + case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4 + case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize) } def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() @@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite { val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) } + + /** Get a test ColumnStat with given distinctCount and nullCount */ + def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat = + ColumnStat(distinctCount = Some(distinctCount), + min = Some(1), max = Some(distinctCount), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1122522ccb4cb..640e01336aa75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** @@ -64,12 +66,12 @@ case class AnalyzeColumnCommand( /** * Compute stats for the given columns. - * @return (row count, map from column name to ColumnStats) + * @return (row count, map from column name to CatalogColumnStats) */ private def computeColumnStats( sparkSession: SparkSession, tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan @@ -81,7 +83,7 @@ case class AnalyzeColumnCommand( // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => - if (!ColumnStat.supportsType(attr.dataType)) { + if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") @@ -103,7 +105,7 @@ case class AnalyzeColumnCommand( // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) + attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -111,9 +113,9 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, - attributePercentiles.get(attr))) + // according to `statExprs`, the stats struct always have 7 fields. + (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) }.toMap (rowCount, columnStats) } @@ -124,7 +126,7 @@ case class AnalyzeColumnCommand( sparkSession: SparkSession, relation: LogicalPlan): AttributeMap[ArrayData] = { val attrsToGenHistogram = if (conf.histogramEnabled) { - attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) } else { Nil } @@ -154,4 +156,120 @@ case class AnalyzeColumnCommand( AttributeMap(attributePercentiles.toSeq) } + /** Returns true iff the we support gathering column statistics on column of the given type. */ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case BinaryType | StringType => + // For string and binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e400975f19708..44749190c79eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -695,10 +695,11 @@ case class DescribeColumnCommand( // Show column stats when EXTENDED or FORMATTED is specified. buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) - buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL")) - buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) - buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) - buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("distinct_count", + cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL")) + buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL")) val histDesc = for { c <- cs hist <- c.histogram diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e798532056..ed4ea0231f1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -95,7 +96,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetchedStats2.get.sizeInBytes == 0) val expectedColStat = - "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + "key" -> CatalogColumnStat(Some(0), None, None, Some(0), + Some(IntegerType.defaultSize), Some(IntegerType.defaultSize)) // There won't be histogram for empty column. Seq("true", "false").foreach { histogramEnabled => @@ -156,7 +158,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(stats, statsWithHgms).foreach { s => s.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + val roundtrip = CatalogColumnStat.fromMap("table_is_foo", field.name, v.toMap(k)) assert(roundtrip == Some(v)) } } @@ -187,7 +189,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared }.mkString(", ")) val expectedColStats = dataTypes.map { case (tpe, idx) => - (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + (s"col$idx", CatalogColumnStat(Some(0), None, None, Some(1), + Some(tpe.defaultSize.toLong), Some(tpe.defaultSize.toLong))) } // There won't be histograms for null columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 65ccc1915882f..bf4abb6e625c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -67,18 +67,21 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) + "cbool" -> CatalogColumnStat(Some(2), Some("false"), Some("true"), Some(1), Some(1), Some(1)), + "cbyte" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(1), Some(1), Some(1)), + "cshort" -> CatalogColumnStat(Some(2), Some("1"), Some("3"), Some(1), Some(2), Some(2)), + "cint" -> CatalogColumnStat(Some(2), Some("1"), Some("4"), Some(1), Some(4), Some(4)), + "clong" -> CatalogColumnStat(Some(2), Some("1"), Some("5"), Some(1), Some(8), Some(8)), + "cdouble" -> CatalogColumnStat(Some(2), Some("1.0"), Some("6.0"), Some(1), Some(8), Some(8)), + "cfloat" -> CatalogColumnStat(Some(2), Some("1.0"), Some("7.0"), Some(1), Some(4), Some(4)), + "cdecimal" -> CatalogColumnStat(Some(2), Some(dec1.toString), Some(dec2.toString), Some(1), + Some(16), Some(16)), + "cstring" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cbinary" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cdate" -> CatalogColumnStat(Some(2), Some(d1.toString), Some(d2.toString), Some(1), Some(4), + Some(4)), + "ctimestamp" -> CatalogColumnStat(Some(2), Some(t1.toString), Some(t2.toString), Some(1), + Some(8), Some(8)) ) /** @@ -110,6 +113,110 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats } + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { @@ -151,7 +258,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils */ def checkColStats( df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) @@ -161,14 +268,24 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats.keys.mkString(", ")) // Validate statistics - val table = getCatalogTable(tableName) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } + validateColStats(tableName, colStats) + } + } + + /** + * Validate if the given catalog table has the provided statistics. + */ + def validateColStats( + tableName: String, + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { + + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) } } } @@ -215,12 +332,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val emptyColStat = ColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) + val emptyCatalogColStat = CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) // Check catalog statistics assert(catalogTable.stats.isDefined) assert(catalogTable.stats.get.sizeInBytes == 0) assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyCatalogColStat)) // Check relation statistics withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 1ee1d57b8ebe1..28c340a176d91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -663,14 +663,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert table statistics to properties so that we can persist them through hive client val statsProperties = if (stats.isDefined) { - statsToProperties(stats.get, schema) + statsToProperties(stats.get) } else { new mutable.HashMap[String, String]() } @@ -1028,9 +1024,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } - private def statsToProperties( - stats: CatalogStatistics, - schema: StructType): Map[String, String] = { + private def statsToProperties(stats: CatalogStatistics): Map[String, String] = { val statsProperties = new mutable.HashMap[String, String]() statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() @@ -1038,11 +1032,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - val colNameTypeMap: Map[String, DataType] = - schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + colStat.toMap(colName).foreach { case (k, v) => + // Fully qualified name used in table properties for a particular column stat. + // For example, for column "mycol", and "min" stat, this should return + // "spark.sql.statistics.colStats.mycol.min". + statsProperties += (STATISTICS_COL_STATS_PREFIX + k -> v) } } @@ -1058,23 +1053,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (statsProps.isEmpty) { None } else { + val colStats = new mutable.HashMap[String, CatalogColumnStat] + val colStatsProps = properties.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)).map { + case (k, v) => k.drop(STATISTICS_COL_STATS_PREFIX.length) -> v + } - val colStats = new mutable.HashMap[String, ColumnStat] - - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. - // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. - val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } - ColumnStat.fromMap(table, field, colStatMap).foreach { cs => - colStats += field.name -> cs - } + // Find all the column names by matching the KEY_VERSION properties for them. + colStatsProps.keys.filter { + k => k.endsWith(CatalogColumnStat.KEY_VERSION) + }.map { k => + k.dropRight(CatalogColumnStat.KEY_VERSION.length + 1) + }.foreach { fieldName => + // and for each, create a column stat. + CatalogColumnStat.fromMap(table, fieldName, colStatsProps).foreach { cs => + colStats += fieldName -> cs } } @@ -1093,14 +1085,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert partition statistics to properties so that we can persist them through hive api val withStatsProps = lowerCasedParts.map { p => if (p.stats.isDefined) { - val statsProperties = statsToProperties(p.stats.get, schema) + val statsProperties = statsToProperties(p.stats.get) p.copy(parameters = p.parameters ++ statsProperties) } else { p @@ -1310,15 +1298,6 @@ object HiveExternalCatalog { val EMPTY_DATA_SCHEMA = new StructType() .add("col", "array", nullable = true, comment = "from deserializer") - /** - * Returns the fully qualified name used in table properties for a particular column stat. - * For example, for column "mycol", and "min" stat, this should return - * "spark.sql.statistics.colStats.mycol.min". - */ - private def columnStatKeyPropName(columnName: String, statKey: String): String = { - STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey - } - // A persisted data source table always store its schema in the catalog. private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3af8af0814bb4..61cec82984795 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils @@ -177,8 +177,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats0.get.colStats == Map( - "a" -> ColumnStat(2, Some(1), Some(2), 0, 4, 4), - "b" -> ColumnStat(1, Some(1), Some(1), 0, 4, 4))) + "a" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(0), Some(4), Some(4)), + "b" -> CatalogColumnStat(Some(1), Some("1"), Some("1"), Some(0), Some(4), Some(4)))) } } @@ -208,8 +208,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "C1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "C1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) } } @@ -596,7 +596,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + assert(fetchedStats0.get.colStats == + Map("c1" -> CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)))) // Insert new data and analyze: have the latest column stats. sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") @@ -604,18 +605,18 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) // Analyze another column: since the table is not changed, the precious column stats are kept. sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats2.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4), - "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, - avgLen = 1, maxLen = 1))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c2" -> CatalogColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)))) // Insert new data and analyze: stale column stats are removed and newly collected column // stats are added. @@ -624,10 +625,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get assert(fetchedStats3.colStats == Map( - "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, - avgLen = 8, maxLen = 8))) + "c1" -> CatalogColumnStat(distinctCount = Some(2), min = Some("1"), max = Some("2"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(2), min = Some("10.0"), max = Some("20.0"), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)))) } } @@ -999,115 +1000,11 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("verify serialized column stats after analyzing columns") { import testImplicits._ - val tableName = "column_stats_test2" + val tableName = "column_stats_test_ser" // (data.head.productArity - 1) because the last column does not support stats collection. assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - val expectedSerializedColStats = Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - "spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - ) - - val expectedSerializedHistograms = Map( - "spark.sql.statistics.colStats.cbyte.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), - "spark.sql.statistics.colStats.cshort.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), - "spark.sql.statistics.colStats.cint.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), - "spark.sql.statistics.colStats.clong.histogram" -> - HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), - "spark.sql.statistics.colStats.cdouble.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), - "spark.sql.statistics.colStats.cfloat.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), - "spark.sql.statistics.colStats.cdecimal.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), - "spark.sql.statistics.colStats.cdate.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), - "spark.sql.statistics.colStats.ctimestamp.histogram" -> - HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) - ) - def checkColStatsProps(expected: Map[String, String]): Unit = { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) val table = hiveClient.getTable("default", tableName) @@ -1129,6 +1026,29 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("verify column stats can be deserialized from tblproperties") { + import testImplicits._ + + val tableName = "column_stats_test_de" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Put in stats properties manually. + val table = getCatalogTable(tableName) + val newTable = table.copy( + properties = table.properties ++ + expectedSerializedColStats ++ expectedSerializedHistograms + + ("spark.sql.statistics.totalSize" -> "1") /* totalSize always required */) + hiveClient.alterTable(newTable) + + validateColStats(tableName, statsWithHgms) + } + } + test("serialization and deserialization of histograms to/from hive metastore") { import testImplicits._ From 649ed9c5732f85ef1306576fdd3a9278a2a6410c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Feb 2018 08:18:41 -0600 Subject: [PATCH 088/593] [SPARK-23509][BUILD] Upgrade commons-net from 2.2 to 3.1 ## What changes were proposed in this pull request? This PR avoids version conflicts of `commons-net` by upgrading commons-net from 2.2 to 3.1. We are seeing the following message during the build using sbt. ``` [warn] Found version conflict(s) in library dependencies; some are suspected to be binary incompatible: ... [warn] * commons-net:commons-net:3.1 is selected over 2.2 [warn] +- org.apache.hadoop:hadoop-common:2.6.5 (depends on 3.1) [warn] +- org.apache.spark:spark-core_2.11:2.4.0-SNAPSHOT (depends on 2.2) [warn] ``` [Here](https://commons.apache.org/proper/commons-net/changes-report.html) is a release history. [Here](https://commons.apache.org/proper/commons-net/migration.html) is a migration guide from 2.x to 3.0. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20672 from kiszk/SPARK-23509. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index ed310507d14ed..c3d1dd444b506 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 04dec04796af4..290867035f91d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/pom.xml b/pom.xml index ac30107066389..b8396166f6b1b 100644 --- a/pom.xml +++ b/pom.xml @@ -579,7 +579,7 @@ commons-net commons-net - 2.2 + 3.1 io.netty From eac0b067222a3dfa52be20360a453cb7bd420bf2 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 27 Feb 2018 08:21:11 -0600 Subject: [PATCH 089/593] [SPARK-17147][STREAMING][KAFKA] Allow non-consecutive offsets ## What changes were proposed in this pull request? Add a configuration spark.streaming.kafka.allowNonConsecutiveOffsets to allow streaming jobs to proceed on compacted topics (or other situations involving gaps between offsets in the log). ## How was this patch tested? Added new unit test justinrmiller has been testing this branch in production for a few weeks Author: cody koeninger Closes #20572 from koeninger/SPARK-17147. --- .../kafka010/CachedKafkaConsumer.scala | 55 +++- .../spark/streaming/kafka010/KafkaRDD.scala | 236 +++++++++++++----- .../streaming/kafka010/KafkaRDDSuite.scala | 106 ++++++++ .../streaming/kafka010/KafkaTestUtils.scala | 25 +- .../kafka010/mocks/MockScheduler.scala | 96 +++++++ .../streaming/kafka010/mocks/MockTime.scala | 51 ++++ 6 files changed, 487 insertions(+), 82 deletions(-) create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala index fa3ea6131a507..aeb8c1dc342b3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala @@ -22,10 +22,8 @@ import java.{ util => ju } import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } import org.apache.kafka.common.{ KafkaException, TopicPartition } -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging - /** * Consumer of single topicpartition, intended for cached reuse. * Underlying consumer is not threadsafe, so neither is this, @@ -38,7 +36,7 @@ class CachedKafkaConsumer[K, V] private( val partition: Int, val kafkaParams: ju.Map[String, Object]) extends Logging { - assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), + require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), "groupId used for cache key must match the groupId in kafkaParams") val topicPartition = new TopicPartition(topic, partition) @@ -53,7 +51,7 @@ class CachedKafkaConsumer[K, V] private( // TODO if the buffer was kept around as a random-access structure, // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator + protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() protected var nextOffset = -2L def close(): Unit = consumer.close() @@ -71,7 +69,7 @@ class CachedKafkaConsumer[K, V] private( } if (!buffer.hasNext()) { poll(timeout) } - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") var record = buffer.next() @@ -79,17 +77,56 @@ class CachedKafkaConsumer[K, V] private( logInfo(s"Buffer miss for $groupId $topic $partition $offset") seek(offset) poll(timeout) - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") record = buffer.next() - assert(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset") + require(record.offset == offset, + s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) } nextOffset = offset + 1 record } + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, timeout: Long): Unit = { + logDebug(s"compacted start $groupId $topic $partition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(timeout: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + private def seek(offset: Long): Unit = { logDebug(s"Seeking to $topicPartition $offset") consumer.seek(topicPartition, offset) @@ -99,7 +136,7 @@ class CachedKafkaConsumer[K, V] private( val p = consumer.poll(timeout) val r = p.records(topicPartition) logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.iterator + buffer = r.listIterator } } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index d9fc9cc206647..07239eda64d2e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -55,12 +55,12 @@ private[spark] class KafkaRDD[K, V]( useConsumerCache: Boolean ) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges { - assert("none" == + require("none" == kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String], ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " must be set to none for executor kafka params, else messages may not match offsetRange") - assert(false == + require(false == kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean], ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + " must be set to false for executor kafka params, else offsets may commit before processing") @@ -74,6 +74,8 @@ private[spark] class KafkaRDD[K, V]( conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64) private val cacheLoadFactor = conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat + private val compacted = + conf.getBoolean("spark.streaming.kafka.allowNonConsecutiveOffsets", false) override def persist(newLevel: StorageLevel): this.type = { logError("Kafka ConsumerRecord is not serializable. " + @@ -87,48 +89,63 @@ private[spark] class KafkaRDD[K, V]( }.toArray } - override def count(): Long = offsetRanges.map(_.count).sum + override def count(): Long = + if (compacted) { + super.count() + } else { + offsetRanges.map(_.count).sum + } override def countApprox( timeout: Long, confidence: Double = 0.95 - ): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[K, V]] = { - val nonEmptyPartitions = this.partitions - .map(_.asInstanceOf[KafkaRDDPartition]) - .filter(_.count > 0) + ): PartialResult[BoundedDouble] = + if (compacted) { + super.countApprox(timeout, confidence) + } else { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[K, V]](0) + override def isEmpty(): Boolean = + if (compacted) { + super.isEmpty() + } else { + count == 0L } - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.count) - result + (part.index -> taken.toInt) + override def take(num: Int): Array[ConsumerRecord[K, V]] = + if (compacted) { + super.take(num) + } else if (num < 1) { + Array.empty[ConsumerRecord[K, V]] + } else { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (nonEmptyPartitions.isEmpty) { + Array.empty[ConsumerRecord[K, V]] } else { - result + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ).flatten } } - val buf = new ArrayBuffer[ConsumerRecord[K, V]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - private def executors(): Array[ExecutorCacheTaskLocation] = { val bm = sparkContext.env.blockManager bm.master.getPeers(bm.blockManagerId).toArray @@ -172,57 +189,138 @@ private[spark] class KafkaRDD[K, V]( override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = { val part = thePart.asInstanceOf[KafkaRDDPartition] - assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { - new KafkaRDDIterator(part, context) + logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + if (compacted) { + new CompactedKafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } else { + new KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } } } +} - /** - * An iterator that fetches messages directly from Kafka for the offsets in partition. - * Uses a cached consumer where possible to take advantage of prefetching - */ - private class KafkaRDDIterator( - part: KafkaRDDPartition, - context: TaskContext) extends Iterator[ConsumerRecord[K, V]] { - - logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + - s"offsets ${part.fromOffset} -> ${part.untilOffset}") - - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching + */ +private class KafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float +) extends Iterator[ConsumerRecord[K, V]] { + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + context.addTaskCompletionListener(_ => closeIfNeeded()) + + val consumer = if (useConsumerCache) { + CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + if (context.attemptNumber >= 1) { + // just in case the prior attempt failures were cache related + CachedKafkaConsumer.remove(groupId, part.topic, part.partition) + } + CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) + } else { + CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + } - context.addTaskCompletionListener{ context => closeIfNeeded() } + var requestOffset = part.fromOffset - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + def closeIfNeeded(): Unit = { + if (!useConsumerCache && consumer != null) { + consumer.close() } + } - var requestOffset = part.fromOffset + override def hasNext(): Boolean = requestOffset < part.untilOffset - def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close - } + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") } + val r = consumer.get(requestOffset, pollTimeout) + requestOffset += 1 + r + } +} - override def hasNext(): Boolean = requestOffset < part.untilOffset - - override def next(): ConsumerRecord[K, V] = { - assert(hasNext(), "Can't call getNext() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeout) - requestOffset += 1 - r +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching. + * Intended for compacted topics, or other cases when non-consecutive offsets are ok. + */ +private class CompactedKafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float + ) extends KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) { + + consumer.compactedStart(part.fromOffset, pollTimeout) + + private var nextRecord = consumer.compactedNext(pollTimeout) + + private var okNext: Boolean = true + + override def hasNext(): Boolean = okNext + + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") + } + val r = nextRecord + if (r.offset + 1 >= part.untilOffset) { + okNext = false + } else { + nextRecord = consumer.compactedNext(pollTimeout) + if (nextRecord.offset >= part.untilOffset) { + okNext = false + consumer.compactedPrevious() + } } + r } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index be373af0599cc..271adea1df731 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -18,16 +18,22 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } +import java.io.File import scala.collection.JavaConverters._ import scala.util.Random +import kafka.common.TopicAndPartition +import kafka.log._ +import kafka.message._ +import kafka.utils.Pool import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.streaming.kafka010.mocks.MockTime class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -64,6 +70,41 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private val preferredHosts = LocationStrategies.PreferConsistent + private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { + val mockTime = new MockTime() + // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api + val logs = new Pool[TopicAndPartition, Log]() + val logDir = kafkaTestUtils.brokerLogDir + val dir = new File(logDir, topic + "-" + partition) + dir.mkdirs() + val logProps = new ju.Properties() + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val log = new Log( + dir, + LogConfig(logProps), + 0L, + mockTime.scheduler, + mockTime + ) + messages.foreach { case (k, v) => + val msg = new ByteBufferMessageSet( + NoCompressionCodec, + new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) + log.append(msg) + } + log.roll() + logs.put(TopicAndPartition(topic, partition), log) + + val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + cleaner.startup() + cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + + cleaner.shutdown() + mockTime.scheduler.shutdown() + } + + test("basic usage") { val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" kafkaTestUtils.createTopic(topic) @@ -102,6 +143,71 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("compacted topic") { + val compactConf = sparkConf.clone() + compactConf.set("spark.streaming.kafka.allowNonConsecutiveOffsets", "true") + sc.stop() + sc = new SparkContext(compactConf) + val topic = s"topiccompacted-${Random.nextInt}-${System.currentTimeMillis}" + + val messages = Array( + ("a", "1"), + ("a", "2"), + ("b", "1"), + ("c", "1"), + ("c", "2"), + ("b", "2"), + ("b", "3") + ) + val compactedMessages = Array( + ("a", "2"), + ("b", "3"), + ("c", "2") + ) + + compactLogs(topic, 0, messages) + + val props = new ju.Properties() + props.put("cleanup.policy", "compact") + props.put("flush.messages", "1") + props.put("segment.ms", "1") + props.put("segment.bytes", "256") + kafkaTestUtils.createTopic(topic, 1, props) + + + val kafkaParams = getKafkaParams() + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, offsetRanges, preferredHosts + ).map(m => m.key -> m.value) + + val received = rdd.collect.toSet + assert(received === compactedMessages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === compactedMessages.size) + assert(rdd.countApprox(0).getFinalValue.mean === compactedMessages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head === compactedMessages.head) + assert(rdd.take(messages.size + 10).size === compactedMessages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts) + .map(_.value) + .collect() + } + } + test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 6c7024ea4b5a5..70b579d96d692 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -162,17 +162,22 @@ private[kafka010] class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, config: Properties): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1, config) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) } } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + createTopic(topic, partitions, new Properties()) + } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { - createTopic(topic, 1) + createTopic(topic, 1, new Properties()) } /** Java-friendly function for sending messages to the Kafka broker */ @@ -196,12 +201,24 @@ private[kafka010] class KafkaTestUtils extends Logging { producer = null } + /** Send the array of (key, value) messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[(String, String)]): Unit = { + producer = new KafkaProducer[String, String](producerConfiguration) + messages.foreach { message => + producer.send(new ProducerRecord[String, String](topic, message._1, message._2)) + } + producer.close() + producer = null + } + + val brokerLogDir = Utils.createTempDir().getAbsolutePath + private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala new file mode 100644 index 0000000000000..928e1a6ef54b9 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.PriorityQueue + +import kafka.utils.{Scheduler, Time} + +/** + * A mock scheduler that executes tasks synchronously using a mock time instance. + * Tasks are executed synchronously when the time is advanced. + * This class is meant to be used in conjunction with MockTime. + * + * Example usage + * + * val time = new MockTime + * time.scheduler.schedule("a task", println("hello world: " + time.milliseconds), delay = 1000) + * time.sleep(1001) // this should cause our scheduled task to fire + * + * + * Incrementing the time to the exact next execution time of a task will result in that task + * executing (it as if execution itself takes no time). + */ +private[kafka010] class MockScheduler(val time: Time) extends Scheduler { + + /* a priority queue of tasks ordered by next execution time */ + var tasks = new PriorityQueue[MockTask]() + + def isStarted: Boolean = true + + def startup(): Unit = {} + + def shutdown(): Unit = synchronized { + tasks.foreach(_.fun()) + tasks.clear() + } + + /** + * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs + * when this method is called and the execution happens synchronously in the calling thread. + * If you are using the scheduler associated with a MockTime instance this call + * will be triggered automatically. + */ + def tick(): Unit = synchronized { + val now = time.milliseconds + while(!tasks.isEmpty && tasks.head.nextExecution <= now) { + /* pop and execute the task with the lowest next execution time */ + val curr = tasks.dequeue + curr.fun() + /* if the task is periodic, reschedule it and re-enqueue */ + if(curr.periodic) { + curr.nextExecution += curr.period + this.tasks += curr + } + } + } + + def schedule( + name: String, + fun: () => Unit, + delay: Long = 0, + period: Long = -1, + unit: TimeUnit = TimeUnit.MILLISECONDS): Unit = synchronized { + tasks += MockTask(name, fun, time.milliseconds + delay, period = period) + tick() + } + +} + +case class MockTask( + val name: String, + val fun: () => Unit, + var nextExecution: Long, + val period: Long) extends Ordered[MockTask] { + def periodic: Boolean = period >= 0 + def compare(t: MockTask): Int = { + java.lang.Long.compare(t.nextExecution, nextExecution) + } +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala new file mode 100644 index 0000000000000..a68f94db1f689 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent._ + +import kafka.utils.Time + +/** + * A class used for unit testing things which depend on the Time interface. + * + * This class never manually advances the clock, it only does so when you call + * sleep(ms) + * + * It also comes with an associated scheduler instance for managing background tasks in + * a deterministic way. + */ +private[kafka010] class MockTime(@volatile private var currentMs: Long) extends Time { + + val scheduler = new MockScheduler(this) + + def this() = this(System.currentTimeMillis) + + def milliseconds: Long = currentMs + + def nanoseconds: Long = + TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) + + def sleep(ms: Long) { + this.currentMs += ms + scheduler.tick() + } + + override def toString(): String = s"MockTime($milliseconds)" + +} From 414ee867ba0835b97aae2e8d4e489e1879c251dd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Feb 2018 08:44:25 -0800 Subject: [PATCH 090/593] [SPARK-23523][SQL] Fix the incorrect result caused by the rule OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? ```Scala val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") .write.json(tablePath.getCanonicalPath) val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() df.show() ``` It generates a wrong result. ``` [c,e,a] ``` We have a bug in the rule `OptimizeMetadataOnlyQuery `. We should respect the attribute order in the original leaf node. This PR is to fix it. ## How was this patch tested? Added a test case Author: gatorsmile Closes #20684 from gatorsmile/optimizeMetadataOnly. --- .../plans/logical/LocalRelation.scala | 9 ++++---- .../execution/OptimizeMetadataOnlyQuery.scala | 12 ++++++++-- .../datasources/HadoopFsRelation.scala | 3 +++ .../OptimizeMetadataOnlyQuerySuite.scala | 22 +++++++++++++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d73d7e73f28d5..b05508db786ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,10 +43,11 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], - data: Seq[InternalRow] = Nil, - // Indicates whether this relation has data from a streaming source. - override val isStreaming: Boolean = false) +case class LocalRelation( + output: Seq[Attribute], + data: Seq[InternalRow] = Nil, + // Indicates whether this relation has data from a streaming source. + override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 18f6f697bc857..0613d9053f826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import java.util.Locale + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -80,8 +83,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val partColumns = partitionColumnNames.map(_.toLowerCase).toSet - relation.output.filter(a => partColumns.contains(a.name.toLowerCase)) + val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + partitionColumnNames.map { colName => + attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), + throw new AnalysisException(s"Unable to find the column `$colName` " + + s"given [${relation.output.map(_.name).mkString(", ")}]") + ) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 6b34638529770..ac574b07ec497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,6 +67,9 @@ case class HadoopFsRelation( } } + // When data schema and partition schema have the overlapped columns, the output + // schema respects the order of data schema for the overlapped columns, but respect + // the data types of partition schema val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 78c1e5dae566d..a543eb8351656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import java.io.File + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SharedSQLContext class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { @@ -125,4 +128,23 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() } } + + test("Incorrect result caused by the rule OptimizeMetadataOnlyQuery") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + withTempPath { path => + val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") + Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") + .write.json(tablePath.getCanonicalPath) + + val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() + checkAnswer(df, Row("a", "e", "c")) + + val localRelation = df.queryExecution.optimizedPlan.collectFirst { + case l: LocalRelation => l + } + assert(localRelation.nonEmpty, "expect to see a LocalRelation") + assert(localRelation.get.output.map(_.name) == Seq("cOl3", "cOl1", "cOl5")) + } + } + } } From ecb8b383af1cf1b67f3111c148229e00c9c17c40 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 27 Feb 2018 11:12:32 -0800 Subject: [PATCH 091/593] [SPARK-23365][CORE] Do not adjust num executors when killing idle executors. The ExecutorAllocationManager should not adjust the target number of executors when killing idle executors, as it has already adjusted the target number down based on the task backlog. The name `replace` was misleading with DynamicAllocation on, as the target number of executors is changed outside of the call to `killExecutors`, so I adjusted that name. Also separated out the logic of `countFailures` as you don't always want that tied to `replace`. While I was there I made two changes that weren't directly related to this: 1) Fixed `countFailures` in a couple cases where it was getting an incorrect value since it used to be tied to `replace`, eg. when killing executors on a blacklisted node. 2) hard error if you call `sc.killExecutors` with dynamic allocation on, since that's another way the ExecutorAllocationManager and the CoarseGrainedSchedulerBackend would get out of sync. Added a unit test case which verifies that the calls to ExecutorAllocationClient do not adjust the number of executors. Author: Imran Rashid Closes #20604 from squito/SPARK-23365. --- .../spark/ExecutorAllocationClient.scala | 15 +++-- .../spark/ExecutorAllocationManager.scala | 20 ++++-- .../scala/org/apache/spark/SparkContext.scala | 13 +++- .../spark/scheduler/BlacklistTracker.scala | 3 +- .../CoarseGrainedSchedulerBackend.scala | 22 ++++--- .../ExecutorAllocationManagerSuite.scala | 66 ++++++++++++++++++- .../StandaloneDynamicAllocationSuite.scala | 3 +- .../scheduler/BlacklistTrackerSuite.scala | 14 ++-- 8 files changed, 121 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 9112d93a86b2a..63d87b4cd385c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -55,18 +55,18 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors will be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * to count those failures toward task failure limits * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ def killExecutors( executorIds: Seq[String], - replace: Boolean = false, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean = false): Seq[String] /** @@ -81,7 +81,8 @@ private[spark] trait ExecutorAllocationClient { * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = { - val killedExecutors = killExecutors(Seq(executorId)) + val killedExecutors = killExecutors(Seq(executorId), adjustTargetNumExecutors = true, + countFailures = false) killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6c59038f2a6c1..189d91333c045 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** @@ -81,7 +82,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, listenerBus: LiveListenerBus, - conf: SparkConf) + conf: SparkConf, + blockManagerMaster: BlockManagerMaster) extends Logging { allocationManager => @@ -151,7 +153,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener + val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. private val executor = @@ -334,6 +336,11 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { + // We lower the target number of executors but don't actively kill any yet. Killing is + // controlled separately by an idle timeout. It's still helpful to reduce the target number + // in case an executor just happens to get lost (eg., bad hardware, or the cluster manager + // preempts it) -- in that case, there is no point in trying to immediately get a new + // executor, since we wouldn't even use it yet. client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") @@ -455,7 +462,10 @@ private[spark] class ExecutorAllocationManager( val executorsRemoved = if (testing) { executorIdsToBeRemoved } else { - client.killExecutors(executorIdsToBeRemoved) + // We don't want to change our target number of executors, because we already did that + // when the task backlog decreased. + client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + countFailures = false, force = false) } // [SPARK-21834] killExecutors api reduces the target number of executors. // So we need to update the target with desired value. @@ -575,7 +585,7 @@ private[spark] class ExecutorAllocationManager( // Note that it is not necessary to query the executors since all the cached // blocks we are concerned with are reported to the driver. Note that this // does not include broadcast blocks. - val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId) val now = clock.getTimeMillis() val timeout = { if (hasCachedBlocks) { @@ -610,7 +620,7 @@ private[spark] class ExecutorAllocationManager( * This class is intentionally conservative in its assumptions about the relative ordering * and consistency of events returned by the listener. */ - private class ExecutorAllocationListener extends SparkListener { + private[spark] class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] // Number of running tasks per stage including speculative tasks. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dc531e3337014..5e8595603cc90 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -534,7 +534,8 @@ class SparkContext(config: SparkConf) extends Logging { schedulerBackend match { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( - schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, + _env.blockManager.master)) case _ => None } @@ -1633,6 +1634,8 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * + * This is not supported when dynamic allocation is turned on. + * * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to @@ -1644,7 +1647,10 @@ class SparkContext(config: SparkConf) extends Logging { def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(executorIds, replace = false, force = true).nonEmpty + require(executorAllocationManager.isEmpty, + "killExecutors() unsupported with Dynamic Allocation turned on") + b.killExecutors(executorIds, adjustTargetNumExecutors = true, countFailures = false, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false @@ -1682,7 +1688,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(Seq(executorId), replace = true, force = true).nonEmpty + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = false, countFailures = true, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index cd8e61d6d0208..952598f6de19d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -152,7 +152,8 @@ private[scheduler] class BlacklistTracker ( case Some(a) => logInfo(s"Killing blacklisted executor id $exec " + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") - a.killExecutors(Seq(exec), true, true) + a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, + force = true) case None => logWarning(s"Not attempting to kill blacklisted executor id $exec " + s"since allocation client is not defined.") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4d75063fbf1c5..5627a557a12f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -147,7 +147,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case KillExecutorsOnHost(host) => scheduler.getExecutorsAliveOnHost(host).foreach { exec => - killExecutors(exec.toSeq, replace = true, force = true) + killExecutors(exec.toSeq, adjustTargetNumExecutors = false, countFailures = false, + force = true) } case UpdateDelegationTokens(newDelegationTokens) => @@ -584,18 +585,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * those failures be counted to task failure limits? * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ final override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") @@ -610,7 +611,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") @@ -618,12 +619,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. val adjustTotalExecutors = - if (!replace) { + if (adjustTargetNumExecutors) { requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) if (requestedTotalExecutors != (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { logDebug( - s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force): + |Executor counts do not match: |requestedTotalExecutors = $requestedTotalExecutors |numExistingExecutors = $numExistingExecutors |numPendingExecutors = $numPendingExecutors diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a0cae5a9e011c..9807d1269e3d4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark import scala.collection.mutable +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics @@ -26,6 +28,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.ManualClock /** @@ -1050,6 +1053,66 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager) === Map.empty) } + test("SPARK-23365 Don't update target num executors when killing idle executors") { + val minExecutors = 1 + val initialExecutors = 1 + val maxExecutors = 2 + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) + .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms") + val mockAllocationClient = mock(classOf[ExecutorAllocationClient]) + val mockBMM = mock(classOf[BlockManagerMaster]) + val manager = new ExecutorAllocationManager( + mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM) + val clock = new ManualClock() + manager.setClock(clock) + + when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true) + // test setup -- job with 2 tasks, scale up to two executors + assert(numExecutorsTarget(manager) === 1) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty))) + manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2))) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 2) + val taskInfo0 = createTaskInfo(0, 0, "executor-1") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0)) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty))) + val taskInfo1 = createTaskInfo(1, 1, "executor-2") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1)) + assert(numExecutorsTarget(manager) === 2) + + // have one task finish -- we should adjust the target number of executors down + // but we should *not* kill any executors yet + manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 1) + verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any()) + + // now we cross the idle timeout for executor-1, so we kill it. the really important + // thing here is that we do *not* ask the executor allocation client to adjust the target + // number of executors down + when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false)) + .thenReturn(Seq("executor-1")) + clock.advance(3000) + schedule(manager) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 1) + // here's the important verify -- we did kill the executors, but did not adjust the target count + verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -1268,7 +1331,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = executorIds override def start(): Unit = sb.start() diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index c21ee7d26f8ca..27cc47496c805 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite syncExecutors(sc) sc.schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = false, force) + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false, + force) case _ => fail("expected coarse grained scheduler") } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index afebcdd7b9e31..06d7afaaff55c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1) // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole @@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) - verify(allocationClientMock).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutors(Seq("2"), false, false, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -571,7 +571,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M conf.set(config.BLACKLIST_KILL_ENABLED, false) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -580,7 +580,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M clock.advance(1000) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) verify(allocationClientMock, never).killExecutorsOnHost(any()) assert(blacklist.executorIdToBlacklistStatus.contains("1")) From 598446b74b61fee272d3aee3a2e9a3fc90a70d6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 27 Feb 2018 11:33:10 -0800 Subject: [PATCH 092/593] [SPARK-23501][UI] Refactor AllStagesPage in order to avoid redundant code As suggested in #20651, the code is very redundant in `AllStagesPage` and modifying it is a copy-and-paste work. We should avoid such a pattern, which is error prone, and have a cleaner solution which avoids code redundancy. existing UTs Author: Marco Gaido Closes #20663 from mgaido91/SPARK-23475_followup. --- .../apache/spark/ui/jobs/AllStagesPage.scala | 261 +++++++----------- 1 file changed, 102 insertions(+), 159 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 38450b9126ff0..4658aa1cea3f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -19,46 +19,20 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, NodeSeq} +import scala.xml.{Attribute, Elem, Node, NodeSeq, Null, Text} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.status.PoolData -import org.apache.spark.status.api.v1._ +import org.apache.spark.status.{AppSummary, PoolData} +import org.apache.spark.status.api.v1.{StageData, StageStatus} import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc + private val subPath = "stages" private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - val allStages = parent.store.stageList(null) - - val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) - val pendingStages = allStages.filter(_.status == StageStatus.PENDING) - val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) - val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) - val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - - val numFailedStages = failedStages.size - val subPath = "stages" - - val activeStagesTable = - new StageTableBase(parent.store, request, activeStages, "active", "activeStage", - parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) - val pendingStagesTable = - new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val completedStagesTable = - new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val skippedStagesTable = - new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val failedStagesTable = - new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", - parent.basePath, subPath, parent.isFairScheduler, false, true) - // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( @@ -67,152 +41,121 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { }.toMap val poolTable = new PoolTable(pools, parent) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowSkippedStages = skippedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val allStatuses = Seq(StageStatus.ACTIVE, StageStatus.PENDING, StageStatus.COMPLETE, + StageStatus.SKIPPED, StageStatus.FAILED) + val allStages = parent.store.stageList(null) val appSummary = parent.store.appSummary() - val completedStageNumStr = if (appSummary.numCompletedStages == completedStages.size) { - s"${appSummary.numCompletedStages}" - } else { - s"${appSummary.numCompletedStages}, only showing ${completedStages.size}" - } + + val (summaries, tables) = allStatuses.map( + summaryAndTableForStatus(allStages, appSummary, _, request)).unzip val summary: NodeSeq =
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - Pending Stages: - {pendingStages.size} -
    • - } - } - { - if (shouldShowCompletedStages) { -
    • - Completed Stages: - {completedStageNumStr} -
    • - } - } - { - if (shouldShowSkippedStages) { -
    • - Skipped Stages: - {skippedStages.size} -
    • - } - } - { - if (shouldShowFailedStages) { -
    • - Failed Stages: - {numFailedStages} -
    • - } - } + {summaries.flatten}
    - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { - -

    - - Fair Scheduler Pools ({pools.size}) -

    -
    ++ -
    - {poolTable.toNodeSeq} -
    - } else { - Seq.empty[Node] - } - } - if (shouldShowActiveStages) { - content ++= - -

    - - Active Stages ({activeStages.size}) -

    -
    ++ -
    - {activeStagesTable.toNodeSeq} -
    - } - if (shouldShowPendingStages) { - content ++= - + val poolsDescription = if (sc.isDefined && isFairScheduler) { +

    - Pending Stages ({pendingStages.size}) + Fair Scheduler Pools ({pools.size})

    ++ -
    - {pendingStagesTable.toNodeSeq} +
    + {poolTable.toNodeSeq}
    + } else { + Seq.empty[Node] + } + + val content = summary ++ poolsDescription ++ tables.flatten.flatten + + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def summaryAndTableForStatus( + allStages: Seq[StageData], + appSummary: AppSummary, + status: StageStatus, + request: HttpServletRequest): (Option[Elem], Option[NodeSeq]) = { + val stages = if (status == StageStatus.FAILED) { + allStages.filter(_.status == status).reverse + } else { + allStages.filter(_.status == status) } - if (shouldShowCompletedStages) { - content ++= - -

    - - Completed Stages ({completedStageNumStr}) -

    -
    ++ -
    - {completedStagesTable.toNodeSeq} -
    + + if (stages.isEmpty) { + (None, None) + } else { + val killEnabled = status == StageStatus.ACTIVE && parent.killEnabled + val isFailedStage = status == StageStatus.FAILED + + val stagesTable = + new StageTableBase(parent.store, request, stages, statusName(status), stageTag(status), + parent.basePath, subPath, parent.isFairScheduler, killEnabled, isFailedStage) + val stagesSize = stages.size + (Some(summary(appSummary, status, stagesSize)), + Some(table(appSummary, status, stagesTable, stagesSize))) } - if (shouldShowSkippedStages) { - content ++= - -

    - - Skipped Stages ({skippedStages.size}) -

    -
    ++ -
    - {skippedStagesTable.toNodeSeq} -
    + } + + private def statusName(status: StageStatus): String = status match { + case StageStatus.ACTIVE => "active" + case StageStatus.COMPLETE => "completed" + case StageStatus.FAILED => "failed" + case StageStatus.PENDING => "pending" + case StageStatus.SKIPPED => "skipped" + } + + private def stageTag(status: StageStatus): String = s"${statusName(status)}Stage" + + private def headerDescription(status: StageStatus): String = statusName(status).capitalize + + private def summaryContent(appSummary: AppSummary, status: StageStatus, size: Int): String = { + if (status == StageStatus.COMPLETE && appSummary.numCompletedStages != size) { + s"${appSummary.numCompletedStages}, only showing $size" + } else { + s"$size" } - if (shouldShowFailedStages) { - content ++= - -

    - - Failed Stages ({numFailedStages}) -

    -
    ++ -
    - {failedStagesTable.toNodeSeq} -
    + } + + private def summary(appSummary: AppSummary, status: StageStatus, size: Int): Elem = { + val summary = +
  • + + {headerDescription(status)} Stages: + + {summaryContent(appSummary, status, size)} +
  • + + if (status == StageStatus.COMPLETE) { + summary % Attribute(None, "id", Text("completed-summary"), Null) + } else { + summary } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def table( + appSummary: AppSummary, + status: StageStatus, + stagesTable: StageTableBase, + size: Int): NodeSeq = { + val classSuffix = s"${statusName(status).capitalize}Stages" + +

    + + {headerDescription(status)} Stages ({summaryContent(appSummary, status, size)}) +

    +
    ++ +
    + {stagesTable.toNodeSeq} +
    } } From 23ac3aaba4a33bc3d31d01f21e93c4681ef6de03 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 28 Feb 2018 09:25:02 +0900 Subject: [PATCH 093/593] [SPARK-23417][PYTHON] Fix the build instructions supplied by exception messages in python streaming tests ## What changes were proposed in this pull request? Fix the build instructions supplied by exception messages in python streaming tests. I also added -DskipTests to the maven instructions to avoid the 170 minutes of scala tests that occurs each time one wants to add a jar to the assembly directory. ## How was this patch tested? - clone branch - run build/sbt package - run python/run-tests --modules "pyspark-streaming" , expect error message - follow instructions in error message. i.e., run build/sbt assembly/package streaming-kafka-0-8-assembly/assembly - rerun python tests, expect error message - follow instructions in error message. i.e run build/sbt -Pflume assembly/package streaming-flume-assembly/assembly - rerun python tests, see success. - repeated all of the above for mvn version of the process. Author: Bruce Robbins Closes #20638 from bersprockets/SPARK-23417_propa. --- python/pyspark/streaming/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..71f8101e34c50 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1477,8 +1477,8 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/sbt -Pkafka-0-8 assembly/package streaming-kafka-0-8-assembly/assembly' or " + "'build/mvn -DskipTests -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1494,8 +1494,8 @@ def search_flume_assembly_jar(): raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or " + "'build/mvn -DskipTests -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) From b14993e1fcb68e1c946a671c6048605ab4afdf58 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Feb 2018 11:00:54 +0900 Subject: [PATCH 094/593] [SPARK-23448][SQL] Clarify JSON and CSV parser behavior in document ## What changes were proposed in this pull request? Clarify JSON and CSV reader behavior in document. JSON doesn't support partial results for corrupted records. CSV only supports partial results for the records with more or less tokens. ## How was this patch tested? Pass existing tests. Author: Liang-Chi Hsieh Closes #20666 from viirya/SPARK-23448-2. --- python/pyspark/sql/readwriter.py | 30 ++++++++++--------- python/pyspark/sql/streaming.py | 30 ++++++++++--------- .../sql/catalyst/json/JacksonParser.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 22 +++++++------- .../datasources/csv/UnivocityParser.scala | 5 ++++ .../sql/streaming/DataStreamReader.scala | 22 +++++++------- 6 files changed, 64 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 49af1bcee5ef8..9d05ac7cb39be 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -209,13 +209,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -393,13 +393,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e2a97acb5e2a7..cc622decfd682 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -442,13 +442,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -621,13 +621,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index bd144c9575c72..7f6956994f31f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -357,6 +357,9 @@ class JacksonParser( } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => + // JSON parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4274f120a375a..0139913aaa4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -345,12 +345,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -550,12 +550,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 7d6d7e7eef926..3d6cc30f2ba83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -203,6 +203,8 @@ class UnivocityParser( case _: BadRecordException => None } } + // For records with less or more tokens than the schema, tries to return partial results + // if possible. throw BadRecordException( () => getCurrentInput, () => getPartialResult(), @@ -218,6 +220,9 @@ class UnivocityParser( row } catch { case NonFatal(e) => + // For corrupted records with the number of tokens same as the schema, + // CSV reader doesn't support partial results. All fields other than the field + // configured by `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => getCurrentInput, () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f23851655350a..61e22fac854f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -236,12 +236,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -316,12 +316,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    From 6a8abe29ef3369b387d9bc2ee3459a6611246ab1 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Wed, 28 Feb 2018 23:16:29 +0800 Subject: [PATCH 095/593] [SPARK-23508][CORE] Fix BlockmanagerId in case blockManagerIdCache cause oom MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … cause oom ## What changes were proposed in this pull request? blockManagerIdCache in BlockManagerId will not remove old values which may cause oom `val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()` Since whenever we apply a new BlockManagerId, it will put into this map. This patch will use guava cahce for blockManagerIdCache instead. A heap dump show in [SPARK-23508](https://issues.apache.org/jira/browse/SPARK-23508) ## How was this patch tested? Exist tests. Author: zhoukang Closes #20667 from caneGuy/zhoukang/fix-history. --- .../org/apache/spark/storage/BlockManagerId.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 2c3da0ee85e06..d4a59c33b974c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -18,7 +18,8 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap + +import com.google.common.cache.{CacheBuilder, CacheLoader} import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi @@ -132,10 +133,17 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + /** + * The max cache size is hardcoded to 10000, since the size of a BlockManagerId + * object is about 48B, the total memory cost should be below 1MB which is feasible. + */ + val blockManagerIdCache = CacheBuilder.newBuilder() + .maximumSize(10000) + .build(new CacheLoader[BlockManagerId, BlockManagerId]() { + override def load(id: BlockManagerId) = id + }) def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) blockManagerIdCache.get(id) } } From fab563b9bd1581112462c0fc0b299ad6510b6564 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 1 Mar 2018 00:44:13 +0900 Subject: [PATCH 096/593] [SPARK-23517][PYTHON] Make `pyspark.util._exception_message` produce the trace from Java side by Py4JJavaError ## What changes were proposed in this pull request? This PR proposes for `pyspark.util._exception_message` to produce the trace from Java side by `Py4JJavaError`. Currently, in Python 2, it uses `message` attribute which `Py4JJavaError` didn't happen to have: ```python >>> from pyspark.util import _exception_message >>> try: ... sc._jvm.java.lang.String(None) ... except Exception as e: ... pass ... >>> e.message '' ``` Seems we should use `str` instead for now: https://github.com/bartdag/py4j/blob/aa6c53b59027925a426eb09b58c453de02c21b7c/py4j-python/src/py4j/protocol.py#L412 but this doesn't address the problem with non-ascii string from Java side - `https://github.com/bartdag/py4j/issues/306` So, we could directly call `__str__()`: ```python >>> e.__str__() u'An error occurred while calling None.java.lang.String.\n: java.lang.NullPointerException\n\tat java.lang.String.(String.java:588)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)\n\tat sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)\n\tat java.lang.reflect.Constructor.newInstance(Constructor.java:422)\n\tat py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)\n\tat py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\n\tat py4j.Gateway.invoke(Gateway.java:238)\n\tat py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)\n\tat py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)\n\tat py4j.GatewayConnection.run(GatewayConnection.java:214)\n\tat java.lang.Thread.run(Thread.java:745)\n' ``` which doesn't type coerce unicodes to `str` in Python 2. This can be actually a problem: ```python from pyspark.sql.functions import udf spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.range(1).select(udf(lambda x: [[]])()).toPandas() ``` **Before** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` **After** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: An error occurred while calling o47.collectAsArrowToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 0.0 failed 1 times, most recent failure: Lost task 7.0 in stage 0.0 (TID 7, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last): File "/.../spark/python/pyspark/worker.py", line 245, in main process() File "/.../spark/python/pyspark/worker.py", line 240, in process ... Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20680 from HyukjinKwon/SPARK-23517. --- python/pyspark/tests.py | 11 +++++++++++ python/pyspark/util.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 511585763cb01..9111dbbed5929 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2293,6 +2293,17 @@ def set(self, x=None, other=None, other_x=None): self.assertEqual(b._x, 2) +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e5d332ce54429..ad4a0bc68ef41 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from py4j.protocol import Py4JJavaError __all__ = [] @@ -33,6 +34,12 @@ def _exception_message(excp): >>> msg == _exception_message(excp) True """ + if isinstance(excp, Py4JJavaError): + # 'Py4JJavaError' doesn't contain the stack trace available on the Java side in 'message' + # attribute in Python 2. We should call 'str' function on this exception in general but + # 'Py4JJavaError' has an issue about addressing non-ascii strings. So, here we work + # around by the direct call, '__str__()'. Please see SPARK-23517. + return excp.__str__() if hasattr(excp, "message"): return excp.message return str(excp) From 476a7f026bc45462067ebd39cd269147e84cd641 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 28 Feb 2018 08:44:53 -0800 Subject: [PATCH 097/593] [SPARK-23514] Use SessionState.newHadoopConf() to propage hadoop configs set in SQLConf. ## What changes were proposed in this pull request? A few places in `spark-sql` were using `sc.hadoopConfiguration` directly. They should be using `sessionState.newHadoopConf()` to blend in configs that were set through `SQLConf`. Also, for better UX, for these configs blended in from `SQLConf`, we should consider removing the `spark.hadoop` prefix, so that the settings are recognized whether or not they were specified by the user. ## How was this patch tested? Tested that AlterTableRecoverPartitions now correctly recognizes settings that are passed in to the FileSystem through SQLConf. Author: Juliusz Sompolski Closes #20679 from juliuszsompolski/SPARK-23514. --- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0142f17ce62e2..964cbca049b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -610,10 +610,10 @@ case class AlterTableRecoverPartitionsCommand( val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val hadoopConf = spark.sessionState.newHadoopConf() + val fs = root.getFileSystem(hadoopConf) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt - val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) @@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand( pathFilter: PathFilter, threshold: Int): GenMap[String, PartitionStatistics] = { if (partitionSpecsAndLocs.length > threshold) { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val serializableConfiguration = new SerializableConfiguration(hadoopConf) val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 19028939f3673..fcf2025d34432 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -518,8 +518,9 @@ private[hive] class TestHiveSparkSession( // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to // delete it. Later, it will be re-created with the right permission. - val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) - val fs = location.getFileSystem(sc.hadoopConfiguration) + val hadoopConf = sessionState.newHadoopConf() + val location = new Path(hadoopConf.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(hadoopConf) fs.delete(location, true) // Some tests corrupt this value on purpose, which breaks the RESET call below. From 25c2776dd9ae3f9792048c78be2cbd958fd99841 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 28 Feb 2018 12:16:26 -0800 Subject: [PATCH 098/593] [SPARK-23523][SQL][FOLLOWUP] Minor refactor of OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? Inside `OptimizeMetadataOnlyQuery.getPartitionAttrs`, avoid using `zip` to generate attribute map. Also include other minor update of comments and format. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #20693 from jiangxb1987/SPARK-23523. --- .../spark/sql/execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../spark/sql/execution/datasources/HadoopFsRelation.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 0613d9053f826..dc4aff9f12580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -83,7 +83,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + val attrMap = relation.output.map(a => a.name.toLowerCase(Locale.ROOT) -> a).toMap partitionColumnNames.map { colName => attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), throw new AnalysisException(s"Unable to find the column `$colName` " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index ac574b07ec497..b2f73b7f8d1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,9 +67,9 @@ case class HadoopFsRelation( } } - // When data schema and partition schema have the overlapped columns, the output - // schema respects the order of data schema for the overlapped columns, but respect - // the data types of partition schema + // When data and partition schemas have overlapping columns, the output + // schema respects the order of the data schema for the overlapping columns, and it + // respects the data types of the partition schema. val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) From 22f3d3334c85c042c6e90f5a02f308d7cd1c1498 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 1 Mar 2018 14:28:28 +0800 Subject: [PATCH 099/593] [SPARK-23389][CORE] When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine =false`, we should be able to use serialized sorting. ## What changes were proposed in this pull request? When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine=false`, in the map side,there is no need for aggregation and sorting, so we should be able to use serialized sorting. ## How was this patch tested? Existing unit test Author: liuxian Closes #20576 from 10110346/mapsidecombine. --- .../scala/org/apache/spark/Dependency.scala | 3 +++ .../spark/shuffle/BlockStoreShuffleReader.scala | 1 - .../spark/shuffle/sort/SortShuffleManager.scala | 6 +++--- .../spark/shuffle/sort/SortShuffleWriter.scala | 2 -- .../shuffle/sort/SortShuffleManagerSuite.scala | 17 +++++++++-------- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ca52ecafa2cc8..9ea6d2fa2fd95 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -76,6 +76,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { + if (mapSideCombine) { + require(aggregator.isDefined, "Map-side combine without Aggregator specified!") + } override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0562d45ff57c5..edd69715c9602 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -90,7 +90,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bfb4dc698e325..d9fad64f34c7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${dependency.serializer.getClass.getName}, does not support object relocation") false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + } else if (dependency.mapSideCombine) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") false } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792bf3..274399b9cc1f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { @@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter { def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") false } else { val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 55cebe7c8b6a8..f29dac965c803 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // We support serialized shuffle if we do not need to do map-side aggregation + assert(canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) } test("unsupported shuffle dependencies for serialized shuffle") { @@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform aggregation - assert(!canUseSerializedShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = None, - aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), - mapSideCombine = false - ))) + // We do not support serialized shuffle if we need to do map-side aggregation assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, From ff1480189b827af0be38605d566a4ee71b4c36f6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Mar 2018 16:26:11 +0800 Subject: [PATCH 100/593] [SPARK-23510][SQL] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? This is based on https://github.com/apache/spark/pull/20668 for supporting Hive 2.2 and Hive 2.3 metastore. When we merge the PR, we should give the major credit to wangyum ## How was this patch tested? Added the test cases Author: Yuming Wang Author: gatorsmile Closes #20671 from gatorsmile/pr-20668. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../sql/hive/client/HiveClientImpl.scala | 3 +- .../spark/sql/hive/client/HiveShim.scala | 8 +-- .../hive/client/IsolatedClientLoader.scala | 2 + .../spark/sql/hive/client/package.scala | 10 +++- .../sql/hive/execution/SaveAsHiveFile.scala | 3 +- .../sql/hive/client/HiveClientVersions.scala | 3 +- .../sql/hive/client/HiveVersionSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 51 +++++++++++++++++-- 9 files changed, 72 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c448c5a9821be..10c9603745379 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.1.1.") + s"0.12.0 through 2.3.2.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 146fa54a1bce4..da9fe2d3088b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -25,7 +25,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf @@ -104,6 +103,8 @@ private[hive] class HiveClientImpl( case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() case hive.v2_1 => new Shim_v2_1() + case hive.v2_2 => new Shim_v2_2() + case hive.v2_3 => new Shim_v2_3() } // Create an internal session state for this HiveClientImpl. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1eac70dbf19cd..948ba542b5733 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -880,9 +880,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { } -private[client] class Shim_v1_0 extends Shim_v0_14 { - -} +private[client] class Shim_v1_0 extends Shim_v0_14 private[client] class Shim_v1_1 extends Shim_v1_0 { @@ -1146,3 +1144,7 @@ private[client] class Shim_v2_1 extends Shim_v2_0 { alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) } } + +private[client] class Shim_v2_2 extends Shim_v2_1 + +private[client] class Shim_v2_3 extends Shim_v2_1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index dac0e333b63bc..12975bc85b971 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -97,6 +97,8 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.2" | "1.2.0" | "1.2.1" | "1.2.2" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 + case "2.2" | "2.2.0" => hive.v2_2 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index c14154a3b3c21..681ee9200f02b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -71,7 +71,15 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) + case object v2_2 extends HiveVersion("2.2.0", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v2_3 extends HiveVersion("2.3.2", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index e484356906e87..6a7b25b36d9a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -114,7 +114,8 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = + Set(v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala index 2e7dfde8b2fa5..30592a3f85428 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala @@ -22,5 +22,6 @@ import scala.collection.immutable.IndexedSeq import org.apache.spark.SparkFunSuite private[client] trait HiveClientVersions { - protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + protected val versions = + IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index a70fb6464cc1d..e5963d03f6b52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -34,7 +34,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 72536b833481a..6176273c88db1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -110,7 +111,8 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + private val versions = + Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") private var client: HiveClient = null @@ -125,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } @@ -422,15 +424,18 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") + val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - val partition = CatalogTablePartition(spec, storage) + val partition = CatalogTablePartition(spec, storage, parameters) client.alterPartitions("default", "src_part", Seq(partition)) assert(client.getPartition("default", "src_part", spec) .storage.locationUri == Some(newLocation)) + assert(client.getPartition("default", "src_part", spec) + .parameters.get(StatsSetupConst.TOTAL_SIZE) == Some("0")) } test(s"$version: dropPartitions") { @@ -633,6 +638,46 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: CREATE Partitioned TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql( + """ + |CREATE TABLE tbl(c1 string) + |PARTITIONED BY (ds STRING) + """.stripMargin) + versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") + + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) + val partMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty && numFiles.isEmpty) + } else { + assert(totalSize.nonEmpty && numFiles.nonEmpty) + } + + versionSpark.sql( + """ + |ALTER TABLE tbl PARTITION (ds='2') + |SET SERDEPROPERTIES ('newKey' = 'vvv') + """.stripMargin) + val newPartMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + + val newTotalSize = newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(newTotalSize.isEmpty && newNumFiles.isEmpty) + } else { + assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) + } + } + } + test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { From cdcccd7b41c43d79edff2fec7a84cd00e9524f75 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei <584620569@qq.com> Date: Fri, 2 Mar 2018 00:09:44 +0800 Subject: [PATCH 101/593] [SPARK-23405] Generate additional constraints for Join's children ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) I run a sql: `select ls.cs_order_number from ls left semi join catalog_sales cs on ls.cs_order_number = cs.cs_order_number`, The `ls` table is a small table ,and the number is one. The `catalog_sales` table is a big table, and the number is 10 billion. The task will be hang up. And i find the many null values of `cs_order_number` in the `catalog_sales` table. I think the null value should be removed in the logical plan. >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > +- MetastoreRelation 100t, catalog_sales Now, use this patch, the plan will be: >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > : **+- Filter isnotnull(cs_order_number#22)** > :+- MetastoreRelation 100t, catalog_sales ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: KaiXinXiaoLei <584620569@qq.com> Author: hanghang <584620569@qq.com> Closes #20670 from KaiXinXiaoLei/Spark-23405. --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../plans/logical/QueryPlanConstraints.scala | 27 ++++++++++--------- .../InferFiltersFromConstraintsSuite.scala | 12 +++++++++ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a28b6a0feb8f9..91208479be03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -661,7 +661,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe case join @ Join(left, right, joinType, conditionOpt) => // Only consider constraints that can be pushed down completely to either the left or the // right child - val constraints = join.constraints.filter { c => + val constraints = join.allConstraints.filter { c => c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) } // Remove those constraints that are already enforced by either the left or the right child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 5c7b8e5b97883..046848875548b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -23,25 +23,28 @@ import org.apache.spark.sql.catalyst.expressions._ trait QueryPlanConstraints { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. + * An [[ExpressionSet]] that contains an additional set of constraints, such as equality + * constraints and `isNotNull` constraints, etc. */ - lazy val constraints: ExpressionSet = { + lazy val allConstraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet( - validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints)) - .filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - } - ) + ExpressionSet(validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints))) } else { ExpressionSet(Set.empty) } } + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + }) + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 178c4b8c270a0..f78c2356e35a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -192,4 +192,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(Optimize.execute(original.analyze), correct.analyze) } + + test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftSemi, condition).analyze + val left = x.where(IsNotNull('a)) + val right = y.where(IsNotNull('a)) + val correctAnswer = left.join(right, LeftSemi, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 34811e0b908449fd59bca476604612b1d200778d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 1 Mar 2018 17:26:39 -0800 Subject: [PATCH 102/593] [SPARK-23551][BUILD] Exclude `hadoop-mapreduce-client-core` dependency from `orc-mapreduce` ## What changes were proposed in this pull request? This PR aims to prevent `orc-mapreduce` dependency from making IDEs and maven confused. **BEFORE** Please note that `2.6.4` at `Spark Project SQL`. ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.orc:orc-mapreduce:jar:nohive:1.4.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.6.4:compile ``` **AFTER** ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile ``` ## How was this patch tested? 1. Pass the Jenkins with `dev/test-dependencies.sh` with the existing dependencies. 2. Manually do the following and see the change. ``` mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ``` Author: Dongjoon Hyun Closes #20704 from dongjoon-hyun/SPARK-23551. --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index b8396166f6b1b..0a711f287a53f 100644 --- a/pom.xml +++ b/pom.xml @@ -1753,6 +1753,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-mapreduce-client-core + org.apache.orc orc-core From 119f6a0e4729aa952e811d2047790a32ee90bf69 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 1 Mar 2018 21:04:01 -0800 Subject: [PATCH 103/593] [SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * BinarizerSuite * BucketedRandomProjectionLSHSuite * BucketizerSuite * ChiSqSelectorSuite * CountVectorizerSuite * DCTSuite.scala * ElementwiseProductSuite * FeatureHasherSuite * HashingTFSuite ## How was this patch tested? It tests itself because it is a bunch of tests! Author: Joseph K. Bradley Closes #20111 from jkbradley/SPARK-22883-streaming-featureAM. --- .../spark/ml/feature/BinarizerSuite.scala | 8 ++-- .../BucketedRandomProjectionLSHSuite.scala | 26 ++++++++--- .../spark/ml/feature/BucketizerSuite.scala | 11 +++-- .../spark/ml/feature/ChiSqSelectorSuite.scala | 36 +++++++-------- .../ml/feature/CountVectorizerSuite.scala | 23 +++++----- .../apache/spark/ml/feature/DCTSuite.scala | 14 +++--- .../ml/feature/ElementwiseProductSuite.scala | 30 ++++++++++--- .../spark/ml/feature/FeatureHasherSuite.scala | 45 +++++++++---------- .../spark/ml/feature/HashingTFSuite.scala | 34 ++++++++------ 9 files changed, 126 insertions(+), 101 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 4455d35210878..05d4a6ee2dabf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BinarizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setInputCol("feature") .setOutputCol("binarized_feature") - binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") { case Row(x: Double, y: Double) => assert(x === y, "The feature value is not correct after binarization.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 7175c721bff36..ed9a39d8d1512 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -20,16 +20,15 @@ package org.apache.spark.ml.feature import breeze.numerics.{cos, sin} import breeze.numerics.constants.Pi -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} -class BucketedRandomProjectionLSHSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -98,6 +97,21 @@ class BucketedRandomProjectionLSHSuite MLTestingUtils.checkCopyAndUids(brp, brpModel) } + test("BucketedRandomProjectionLSH: streaming transform") { + val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + val brpModel = brp.fit(dataset) + + testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") { + case Row(values: Seq[_]) => + assert(values.length === brp.getNumHashTables) + } + } + test("BucketedRandomProjectionLSH: test of LSH property") { // Project from 2 dimensional Euclidean Space to 1 dimensions val brp = new BucketedRandomProjectionLSH() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 41cf72fe3470a..9ea15e1918532 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(splits) bucketizer.setHandleInvalid("keep") - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c83909c4498f2..c843df9f33e3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - val model = ChiSqSelectorSuite.testSelector(selector, dataset) + val model = testSelector(selector, dataset) MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fpr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fdr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fwe") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("read/write") { @@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(expected.selectedFeatures === actual.selectedFeatures) } } -} -object ChiSqSelectorSuite { - - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { - val selectorModel = selector.fit(dataset) - selectorModel.transform(dataset).select("filtered", "topFeature").collect() - .foreach { case Row(vec1: Vector, vec2: Vector) => + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(data) + testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, + "filtered", "topFeature") { + case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) - } + } selectorModel } +} + +object ChiSqSelectorSuite { /** * Mapping from all Params to valid settings which differ from the defaults. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 1784c07ca23e3..61217669d9277 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(cv, cvm) assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cvm.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel2.vocabulary === Array("a", "b")) - cvModel2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel3.vocabulary === Array("a", "b")) - cvModel3.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -219,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -238,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(0.3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -258,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setOutputCol("features") .setBinary(true) .fit(df) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -268,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setBinary(true) - cv2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 8dd3dd75e1be5..6734336aac39c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -21,16 +21,14 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DCTSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setOutputCol("resultVec") .setInverse(inverse) - transformer.transform(dataset) - .select("resultVec", "wantedVec") - .collect() - .foreach { case Row(resultVec: Vector, wantedVec: Vector) => - assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") { + case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala index a4cca27be7815..3a8d0762e2ab7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -17,13 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.sql.Row -class ElementwiseProductSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + test("streaming transform") { + val scalingVec = Vectors.dense(0.1, 10.0) + val data = Seq( + (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)), + (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0)) + ) + val df = spark.createDataFrame(data).toDF("features", "expected") + val ep = new ElementwiseProduct() + .setInputCol("features") + .setOutputCol("actual") + .setScalingVec(scalingVec) + testTransformer[(Vector, Vector)](df, ep, "actual", "expected") { + case Row(actual: Vector, expected: Vector) => + assert(actual ~== expected relTol 1e-14) + } + } test("read/write") { val ep = new ElementwiseProduct() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 7bc1825b69c43..d799ba6011fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -17,27 +17,24 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FeatureHasherSuite extends SparkFunSuite - with MLlibTestSparkContext - with DefaultReadWriteTest { +class FeatureHasherSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import FeatureHasherSuite.murmur3FeatureIdx - implicit private val vectorEncoder = ExpressionEncoder[Vector]() + implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]() test("params") { ParamsSuite.checkParams(new FeatureHasher) @@ -52,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite } test("feature hashing") { + val numFeatures = 100 + // Assume perfect hash on field names in computing expected results + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val df = Seq( - (2.0, true, "1", "foo"), - (3.0, false, "2", "bar") - ).toDF("real", "bool", "stringNum", "string") + (2.0, true, "1", "foo", + Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), + (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))), + (3.0, false, "2", "bar", + Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), + (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))) + ).toDF("real", "bool", "stringNum", "string", "expected") - val n = 100 val hasher = new FeatureHasher() .setInputCols("real", "bool", "stringNum", "string") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hasher.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - assert(attrGroup.numAttributes === Some(n)) + assert(attrGroup.numAttributes === Some(numFeatures)) - val features = output.select("features").as[Vector].collect() - // Assume perfect hash on field names - def idx: Any => Int = murmur3FeatureIdx(n) - // check expected indices - val expected = Seq( - Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), - (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))), - Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), - (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))) - ) - assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) + testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14 ) + } } test("setting explicit numerical columns to treat as categorical") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index a46272fdce1fb..c5183ecfef7d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class HashingTFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import HashingTFSuite.murmur3FeatureIdx @@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") - val n = 100 + val numFeatures = 100 + // Assume perfect hash when computing expected features. + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val data = Seq( + ("a a b b c d".split(" ").toSeq, + Vectors.sparse(numFeatures, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))) + ) + + val df = data.toDF("words", "expected") val hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hashingTF.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - require(attrGroup.numAttributes === Some(n)) - val features = output.select("features").first().getAs[Vector](0) - // Assume perfect hash on "a", "b", "c", and "d". - def idx: Any => Int = murmur3FeatureIdx(n) - val expected = Vectors.sparse(n, - Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) - assert(features ~== expected absTol 1e-14) + require(attrGroup.numAttributes === Some(numFeatures)) + + testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("applying binary term freqs") { From 0b6ceadeb563205cbd6bd03bc88e608086273b5b Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 2 Mar 2018 09:23:39 -0800 Subject: [PATCH 104/593] [SPARKR][DOC] fix link in vignettes ## What changes were proposed in this pull request? Fix doc link that was changed in 2.3 shivaram Author: Felix Cheung Closes #20711 from felixcheung/rvigmean. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index feca617c2554c..d4713de7806a1 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,7 +46,7 @@ Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ## Overview -SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](https://spark.apache.org/mllib/). ## Getting Started @@ -132,7 +132,7 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](https://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() @@ -147,7 +147,7 @@ sparkR.session(sparkHome = "/HOME/spark") ### Spark Session {#SetupSparkSession} -In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](https://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](https://spark.apache.org/docs/latest/api/R/sparkR.session.html). In particular, the following Spark driver properties can be set in `sparkConfig`. @@ -169,7 +169,7 @@ sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) #### Cluster Mode -SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with ```{r, echo=FALSE, tidy = TRUE} @@ -177,7 +177,7 @@ paste("Spark", packageVersion("SparkR")) ``` It should be used both on the local computer and on the remote cluster. -To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](https://spark.apache.org/docs/latest/submitting-applications.html#master-urls). For example, to connect to a local standalone Spark master, we can call ```{r, eval=FALSE} @@ -317,7 +317,7 @@ A common flow of grouping and aggregation is 2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. -A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there. For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. @@ -935,7 +935,7 @@ perplexity #### Alternating Least Squares -`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). +`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](https://dl.acm.org/citation.cfm?id=1608614). There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. @@ -1171,11 +1171,11 @@ env | map ## References -* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) +* [Spark Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) -* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) +* [Submitting Spark Applications](https://spark.apache.org/docs/latest/submitting-applications.html) -* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) +* [Machine Learning Library Guide (MLlib)](https://spark.apache.org/docs/latest/ml-guide.html) * [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. From 3a4d15e5d2b9ddbaeb2a6ab2d86d059ada6407b2 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 2 Mar 2018 10:38:50 -0800 Subject: [PATCH 105/593] [SPARK-23518][SQL] Avoid metastore access when the users only want to read and write data frames ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18944 added one patch, which allowed a spark session to be created when the hive metastore server is down. However, it did not allow running any commands with the spark session. This brings troubles to the user who only wants to read / write data frames without metastore setup. ## How was this patch tested? Added some unit tests to read and write data frames based on the original HiveMetastoreLazyInitializationSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20681 from liufengdb/completely-lazy. --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++---- .../sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../HiveMetastoreLazyInitializationSuite.scala | 14 ++++++++++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 8 ++++---- .../spark/sql/hive/HiveSessionStateBuilder.scala | 10 +++++----- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5197838eaac66..bd0a0dcd0674c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -67,6 +67,8 @@ sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) +# materialize the catalog implementation +listTables() mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 4b119c75260a7..64e7ca11270b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,8 +54,8 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - val externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -70,8 +70,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf) { this( - externalCatalog, - new GlobalTempViewManager("global_temp"), + () => externalCatalog, + () => new GlobalTempViewManager("global_temp"), functionRegistry, conf, new Configuration(), @@ -87,6 +87,9 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } + lazy val externalCatalog = externalCatalogBuilder() + lazy val globalTempViewManager = globalTempViewManagerBuilder() + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") protected val tempViews = new mutable.HashMap[String, LogicalPlan] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 007f8760edf82..3a0db7e16c23a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -130,8 +130,8 @@ abstract class BaseSessionStateBuilder( */ protected lazy val catalog: SessionCatalog = { val catalog = new SessionCatalog( - session.sharedState.externalCatalog, - session.sharedState.globalTempViewManager, + () => session.sharedState.externalCatalog, + () => session.sharedState.globalTempViewManager, functionRegistry, conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index 3f135cc864983..277df548aefd0 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -38,6 +38,20 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { // We should be able to run Spark jobs without Hive client. assert(spark.sparkContext.range(0, 1).count() === 1) + // We should be able to use Spark SQL if no table references. + assert(spark.sql("select 1 + 1").count() === 1) + assert(spark.range(0, 1).count() === 1) + + // We should be able to use fs + val path = Utils.createTempDir() + path.delete() + try { + spark.range(0, 1).write.parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).count() === 1) + } finally { + Utils.deleteRecursively(path) + } + // Make sure that we are not using the local derby metastore. val exceptionString = Utils.exceptionString(intercept[AnalysisException] { spark.sql("show tables") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1f11adbd4f62e..e5aff3b99d0b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalog: HiveExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => HiveExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, @@ -48,8 +48,8 @@ private[sql] class HiveSessionCatalog( parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( - externalCatalog, - globalTempViewManager, + externalCatalogBuilder, + globalTempViewManagerBuilder, functionRegistry, conf, hadoopConf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 12c74368dd184..40b9bb51ca9a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,8 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client - new HiveSessionResourceLoader(session, client) + new HiveSessionResourceLoader(session, () => externalCatalog.client) } /** @@ -51,8 +50,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session */ override protected lazy val catalog: HiveSessionCatalog = { val catalog = new HiveSessionCatalog( - externalCatalog, - session.sharedState.globalTempViewManager, + () => externalCatalog, + () => session.sharedState.globalTempViewManager, new HiveMetastoreCatalog(session), functionRegistry, conf, @@ -105,8 +104,9 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session class HiveSessionResourceLoader( session: SparkSession, - client: HiveClient) + clientBuilder: () => HiveClient) extends SessionResourceLoader(session) { + private lazy val client = clientBuilder() override def addJar(path: String): Unit = { client.addJar(path) super.addJar(path) From 707e6506d0dbdb598a6c99d666f3c66746113b67 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 2 Mar 2018 12:27:42 -0800 Subject: [PATCH 106/593] [SPARK-23097][SQL][SS] Migrate text socket source to V2 ## What changes were proposed in this pull request? This PR moves structured streaming text socket source to V2. Questions: do we need to remove old "socket" source? ## How was this patch tested? Unit test and manual verification. Author: jerryshao Closes #20382 from jerryshao/SPARK-23097. --- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../execution/datasources/DataSource.scala | 5 +- .../streaming/{ => sources}/socket.scala | 178 ++++++---- .../sql/streaming/DataStreamReader.scala | 21 +- .../streaming/TextSocketStreamSuite.scala | 231 ------------- .../sources/TextSocketStreamSuite.scala | 306 ++++++++++++++++++ 6 files changed, 434 insertions(+), 309 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => sources}/socket.scala (51%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 0259c774bbf4a..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 6e1b5727e3fd5..35fcff69b14d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -563,6 +564,7 @@ object DataSource extends Logging { val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName + val socket = classOf[TextSocketSourceProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -583,7 +585,8 @@ object DataSource extends Logging { "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, - "com.databricks.spark.csv" -> csv + "com.databricks.spark.csv" -> csv, + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 0b22cbc46e6bf..5aae46b463398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -15,27 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -import org.apache.spark.unsafe.types.UTF8String - -object TextSocketSource { +object TextSocketMicroBatchReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -43,12 +45,17 @@ object TextSocketSource { } /** - * A source that reads text lines through a TCP socket, designed only for tutorials and debugging. - * This source will *not* work in production applications due to multiple reasons, including no - * support for fault recovery and keeping all of the text read in memory forever. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. */ -class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt @GuardedBy("this") private var socket: Socket = null @@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(String, Timestamp)] @GuardedBy("this") - protected var currentOffset: LongOffset = new LongOffset(-1) + private var currentOffset: LongOffset = LongOffset(-1L) @GuardedBy("this") - protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + private var lastOffsetCommitted: LongOffset = LongOffset(-1L) initialize() + /** This method is only used for unit test */ + private[sources] def getCurrentOffset(): LongOffset = synchronized { + currentOffset.copy() + } + private def initialize(): Unit = synchronized { socket = new Socket(host, port) val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) @@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo logWarning(s"Stream closed by $host:$port") return } - TextSocketSource.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = (line, Timestamp.valueOf( - TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - ) - currentOffset = currentOffset + 1 + TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + currentOffset += 1 batches.append(newData) } } @@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo readThread.start() } - /** Returns the schema of the data from this source */ - override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP - else TextSocketSource.SCHEMA_REGULAR + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } + + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None + override def readSchema(): StructType = { + if (options.getBoolean("includeTimestamp", false)) { + TextSocketMicroBatchReader.SCHEMA_TIMESTAMP } else { - Some(currentOffset) + TextSocketMicroBatchReader.SCHEMA_REGULAR } } - /** Returns the data that is between the offsets (`start`, `end`]. */ - override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") + + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo batches.slice(sliceStart, sliceEnd) } - val rdd = sqlContext.sparkContext - .parallelize(rawList) - .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + assert(SparkSession.getActiveSession.isDefined) + val spark = SparkSession.getActiveSession.get + val numPartitions = spark.sparkContext.defaultParallelism + + val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + rawList.zipWithIndex.foreach { case (r, idx) => + slices(idx % numPartitions).append(r) + } + + (0 until numPartitions).map { i => + val slice = slices(i) + new DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new DataReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } + + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } + + override def close(): Unit = {} + } + } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo } } - override def toString: String = s"TextSocketSource[host: $host, port: $port]" + override def toString: String = s"TextSocket[host: $host, port: $port]" } -class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { - private def parseIncludeTimestamp(params: Map[String, String]): Boolean = { - Try(params.getOrElse("includeTimestamp", "false").toBoolean) match { - case Success(bool) => bool - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } +class TextSocketSourceProvider extends DataSourceV2 + with MicroBatchReadSupport with DataSourceRegister with Logging { - /** Returns the name and schema of the source that can be used to continually read data. */ - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { + private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") - if (!parameters.contains("host")) { + if (!params.get("host").isPresent) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!parameters.contains("port")) { + if (!params.get("port").isPresent) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - if (schema.nonEmpty) { - throw new AnalysisException("The socket source does not support a user-specified schema.") + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") } - - val sourceSchema = - if (parseIncludeTimestamp(parameters)) { - TextSocketSource.SCHEMA_TIMESTAMP - } else { - TextSocketSource.SCHEMA_REGULAR - } - ("textSocket", sourceSchema) } - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val host = parameters("host") - val port = parameters("port").toInt - new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext) + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 61e22fac854f9..c393dcdfdd7e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } ds match { case s: MicroBatchReadSupport => - val tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + var tempReader: MicroBatchReader = null + val schema = try { + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReader != null) { + tempReader.stop() + tempReader = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( Optional.ofNullable(userSpecifiedSchema.orNull), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala deleted file mode 100644 index ec11549073650..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io.{IOException, OutputStreamWriter} -import java.net.ServerSocket -import java.sql.Timestamp -import java.util.concurrent.LinkedBlockingQueue - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} - -class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { - import testImplicits._ - - override def afterEach() { - sqlContext.streams.active.foreach(_.stop()) - if (serverThread != null) { - serverThread.interrupt() - serverThread.join() - serverThread = null - } - if (source != null) { - source.stop() - source = null - } - } - - private var serverThread: ServerThread = null - private var source: Source = null - - test("basic usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - assert(batch1.as[String].collect().toSeq === Seq("hello")) - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - assert(batch2.as[String].collect().toSeq === Seq("world")) - - val both = source.getBatch(None, offset2) - assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("timestamped usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString, - "includeTimestamp" -> "true") - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq - assert(batch1Seq.map(_._1) === Seq("hello")) - val batch1Stamp = batch1Seq(0)._2 - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq - assert(batch2Seq.map(_._1) === Seq("world")) - val batch2Stamp = batch2Seq(0)._2 - assert(!batch2Stamp.before(batch1Stamp)) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("params not given") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map()) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) - } - } - - test("non-boolean includeTimestamp") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost", - "port" -> "1234", "includeTimestamp" -> "fasle")) - } - } - - test("user-specified schema given") { - val provider = new TextSocketSourceProvider - val userSpecifiedSchema = StructType( - StructField("name", StringType) :: - StructField("area", StringType) :: Nil) - val exception = intercept[AnalysisException] { - provider.sourceSchema( - sqlContext, Some(userSpecifiedSchema), - "", - Map("host" -> "localhost", "port" -> "1234")) - } - assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) - } - - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - source = provider.createSource(sqlContext, "", None, "", parameters) - } - } - - test("input row metrics") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val batch = source.getBatch(None, source.getOffset.get).as[String] - batch.collect() - val numRowsMetric = - batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") - assert(numRowsMetric.nonEmpty) - assert(numRowsMetric.get.value === 1) - } - source.stop() - source = null - } - } - - private class ServerThread extends Thread with Logging { - private val serverSocket = new ServerSocket(0) - private val messageQueue = new LinkedBlockingQueue[String]() - - val port = serverSocket.getLocalPort - - override def run(): Unit = { - try { - val clientSocket = serverSocket.accept() - clientSocket.setTcpNoDelay(true) - val out = new OutputStreamWriter(clientSocket.getOutputStream) - while (true) { - val line = messageQueue.take() - out.write(line + "\n") - out.flush() - } - } catch { - case e: InterruptedException => - } finally { - serverSocket.close() - } - } - - def enqueue(line: String): Unit = { - messageQueue.put(line) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala new file mode 100644 index 0000000000000..a15a980bb92fd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.IOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.ServerSocketChannel +import java.sql.Timestamp +import java.util.Optional +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { + + override def afterEach() { + sqlContext.streams.active.foreach(_.stop()) + if (serverThread != null) { + serverThread.interrupt() + serverThread.join() + serverThread = null + } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } + } + + private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null + + case class AddSocketData(data: String*) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active socket source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + } + if (sources.isEmpty) { + throw new Exception( + "Could not find socket source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the socket source in the StreamExecution logical plan as there" + + "are multiple socket sources:\n\t" + sources.mkString("\n\t")) + } + val socketSource = sources.head + + assert(serverThread != null && serverThread.port != 0) + val currOffset = socketSource.getCurrentOffset() + data.foreach(serverThread.enqueue) + + val newOffset = LongOffset(currOffset.offset + data.size) + (socketSource, newOffset) + } + + override def toString: String = s"AddSocketData(data = $data)" + } + + test("backward compatibility with old path") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[TextSocketSourceProvider]) + case _ => + throw new IllegalStateException("Could not find socket source") + } + } + + test("basic usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + } + } + + test("timestamped usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val socket = spark + .readStream + .format("socket") + .options(Map( + "host" -> "localhost", + "port" -> serverThread.port.toString, + "includeTimestamp" -> "true")) + .load() + + assert(socket.schema === StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil)) + + var batch1Stamp: Timestamp = null + var batch2Stamp: Timestamp = null + + val curr = System.currentTimeMillis() + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "hello") + batch1Stamp = rows.head.getAs[Timestamp](1) + Thread.sleep(10) + }, + true), + AddSocketData("world"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "world") + batch2Stamp = rows.head.getAs[Timestamp](1) + }, + true), + StopStream + ) + + // Timestamp for rate stream is round to second which leads to milliseconds lost, that will + // make batch1stamp smaller than current timestamp if both of them are in the same second. + // Comparing by second to make sure the correct behavior. + assert(batch1Stamp.getTime >= curr / 1000 * 1000) + assert(!batch2Stamp.before(batch1Stamp)) + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) + } + } + + test("non-boolean includeTimestamp") { + val provider = new TextSocketSourceProvider + val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") + intercept[AnalysisException] { + val a = new DataSourceOptions(params.asJava) + provider.createMicroBatchReader(Optional.empty(), "", a) + } + } + + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val params = Map("host" -> "localhost", "port" -> "1234") + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + batchReader = provider.createMicroBatchReader( + Optional.empty(), "", new DataSourceOptions(parameters.asJava)) + } + } + + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AssertOnQuery { q => + val numRowMetric = + q.lastExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + numRowMetric.nonEmpty && numRowMetric.get.value == 1 + }, + StopStream + ) + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(0)) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocketChannel.socket().getLocalPort + + override def run(): Unit = { + try { + while (true) { + val clientSocketChannel = serverSocketChannel.accept() + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + + // Check whether remote client is closed but still send data to this closed socket. + // This happens in DataStreamReader where a source will be created to get the schema. + var remoteIsClosed = false + var cnt = 0 + while (cnt < 3 && !remoteIsClosed) { + if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { + cnt += 1 + Thread.sleep(100) + } else { + remoteIsClosed = true + } + } + + if (remoteIsClosed) { + logInfo(s"remote client ${clientSocketChannel.socket()} is closed") + } else { + while (true) { + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) + } + } + } + } catch { + case e: InterruptedException => + } finally { + serverSocketChannel.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} From 487377e693af65b2ff3d6b874ca7326c1ff0076c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 2 Mar 2018 14:30:37 -0800 Subject: [PATCH 107/593] [SPARK-23570][SQL] Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite since Spark 2.3.0 is released for ensuring backward compatibility. ## How was this patch tested? N/A Author: gatorsmile Closes #20720 from gatorsmile/add2.3. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index c13a750dbb270..6ca58e68d31eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") protected var spark: SparkSession = _ From 9e26473c0f29ee4281519104ac5e182a3bd4bf23 Mon Sep 17 00:00:00 2001 From: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Date: Fri, 2 Mar 2018 16:24:29 -0800 Subject: [PATCH 108/593] [SPARK-3159][ML] Add decision tree pruning ## What changes were proposed in this pull request? Added subtree pruning in the translation from LearningNode to Node: a learning node having a single prediction value for all the leaves in the subtree rooted at it is translated into a LeafNode, instead of a (redundant) InternalNode ## How was this patch tested? Added two unit tests under "mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala": - test("SPARK-3159 tree model redundancy - classification") - test("SPARK-3159 tree model redundancy - regression") 4 existing unit tests relying on the tree structure (existence of a specific redundant subtree) had to be adapted as the tested components in the output tree are now pruned (fixed by adding an extra _prune_ parameter which can be used to disable pruning for testing) Author: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Closes #20632 from asolimando/master. --- .../scala/org/apache/spark/ml/tree/Node.scala | 22 ++-- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../DecisionTreeClassifierSuite.scala | 38 ------- .../ml/tree/impl/RandomForestSuite.scala | 100 ++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 10 +- 5 files changed, 115 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..d30be452a436e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, - InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. @@ -266,15 +265,23 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { + def toNode: Node = toNode(prune = true) + /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode: Node = { - if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + def toNode(prune: Boolean = true): Node = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + case (l, r) => + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) + } } else { if (stats.valid) { new LeafNode(stats.impurityCalculator.predict, stats.impurity, @@ -283,7 +290,6 @@ private[tree] class LearningNode( // Here we want to keep same behavior with the old mllib.DecisionTreeModel new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553b..8e514f11e78ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation[_]], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 98c879ece62d6..38b265d62611b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite dt.fit(df) } - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. - val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val data = sc.parallelize(arr) - val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val dt = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(1) - .setMaxBins(3) - val model = dt.fit(df) - model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case other => - fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") - } - case other => - fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") - } - } - test("Feature importance with toy data") { val dt = new DecisionTreeClassifier() .setImpurity("gini") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dbe2ea931fb9c..5f0d26eb5c058 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + private val seed = 42 + ///////////////////////////////////////////////////////////////////////////// // Tests for split calculation ///////////////////////////////////////////////////////////////////////////// @@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head + seed = 42, instr = None, prune = false).head + model.rootNode match { case n: InternalNode => n.split match { case s: CategoricalSplit => @@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip Vectors.sparse(size, indices.toArray, values.toArray) } + + @tailrec + private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { + if (nodes.isEmpty) { + acc + } + else { + nodes.head match { + case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 441d0f7614bf6..bc59f3f4125fb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite { Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { - if (i < 1000) { + if (i < 1001) { arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) From dea381dfaa73e0cfb9a833b79c741b15ae274f64 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Sat, 3 Mar 2018 09:10:48 +0800 Subject: [PATCH 109/593] [SPARK-23514][FOLLOW-UP] Remove more places using sparkContext.hadoopConfiguration directly ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/20679 I missed a few places in SQL tests. For hygiene, they should also use the sessionState interface where possible. ## How was this patch tested? Modified existing tests. Author: Juliusz Sompolski Closes #20718 from juliuszsompolski/SPARK-23514-followup. --- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/execution/datasources/FileIndexSuite.scala | 2 +- .../execution/datasources/parquet/ParquetCommitterSuite.scala | 2 +- .../datasources/parquet/ParquetFileFormatSuite.scala | 4 ++-- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- .../sql/execution/datasources/parquet/ParquetQuerySuite.scala | 2 +- .../org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index b5d4c558f0d3e..73e3df3b6202e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -124,7 +124,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) val thirdPath = new Path(basePath, "third") - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = thirdPath.getFileSystem(spark.sessionState.newHadoopConf()) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b800e6ff5b0ce..db9023b7ec8b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1052,7 +1052,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index b4616826e40b3..18bb4bfe661ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -59,7 +59,7 @@ class FileIndexSuite extends SharedSQLContext { require(!unqualifiedDirPath.toString.contains("file:")) require(!unqualifiedFilePath.toString.contains("file:")) - val fs = unqualifiedDirPath.getFileSystem(sparkContext.hadoopConfiguration) + val fs = unqualifiedDirPath.getFileSystem(spark.sessionState.newHadoopConf()) val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index caa4f6d70c6a9..f3ecc5ced689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -101,7 +101,7 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils if (check) { result = Some(MarkingFileOutput.checkMarker( destPath, - spark.sparkContext.hadoopConfiguration)) + spark.sessionState.newHadoopConf())) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index ccb34355f1bac..3a0867fd2b78b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -29,7 +29,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo test("read parquet footers in parallel") { def testReadFooters(ignoreCorruptFiles: Boolean): Unit = { withTempDir { dir => - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val basePath = dir.getCanonicalPath val path1 = new Path(basePath, "first") @@ -44,7 +44,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten val footers = ParquetFileFormat.readParquetFootersInParallel( - sparkContext.hadoopConfiguration, fileStatuses, ignoreCorruptFiles) + spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles) assert(footers.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index e3edafa9c84e1..fbd83a0fa425a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -163,7 +163,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // Just to be defensive in case anything ever changes in parquet, this test checks // the assumption on column stats, and also the end-to-end behavior. - val hadoopConf = sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val fs = FileSystem.get(hadoopConf) val parts = fs.listStatus(new Path(tableDir.getAbsolutePath), new PathFilter { override def accept(path: Path): Boolean = !path.getName.startsWith("_") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 55b0f729be8ce..e1f094d0a7af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -819,7 +819,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val path = dir.getCanonicalPath spark.range(3).write.parquet(path) - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val files = fs.listFiles(new Path(path), true) while (files.hasNext) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index ba48bc1ce0c4d..31e5527d7366a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -353,7 +353,7 @@ class FileStreamSinkSuite extends StreamTest { } test("FileStreamSink.ancestorIsMetadataDirectory()") { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() def assertAncestorIsMetadataDirectory(path: String): Unit = assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) def assertAncestorIsNotMetadataDirectory(path: String): Unit = From 486f99eefead4e664a30a861eca65cab8568e70b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Mar 2018 18:14:13 -0800 Subject: [PATCH 110/593] [SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions ## What changes were proposed in this pull request? Currently, when the Kafka source reads from Kafka, it generates as many tasks as the number of partitions in the topic(s) to be read. In some case, it may be beneficial to read the data with greater parallelism, that is, with more number partitions/tasks. That means, offset ranges must be divided up into smaller ranges such the number of records in partition ~= total records in batch / desired partitions. This would also balance out any data skews between topic-partitions. In this patch, I have added a new option called `minPartitions`, which allows the user to specify the desired level of parallelism. ## How was this patch tested? New tests in KafkaMicroBatchV2SourceSuite. Author: Tathagata Das Closes #20698 from tdas/SPARK-23541. --- .../sql/kafka010/KafkaMicroBatchReader.scala | 109 ++++++------- .../kafka010/KafkaOffsetRangeCalculator.scala | 105 +++++++++++++ .../sql/kafka010/KafkaSourceProvider.scala | 7 + .../apache/spark/sql/kafka010/package.scala | 24 +++ .../kafka010/KafkaMicroBatchSourceSuite.scala | 56 ++++++- .../KafkaOffsetRangeCalculatorSuite.scala | 147 ++++++++++++++++++ 6 files changed, 388 insertions(+), 60 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index fb647ca7e70dd..8a5f3a249b11c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import org.apache.commons.io.IOUtils -import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -64,8 +63,6 @@ private[kafka010] class KafkaMicroBatchReader( failOnDataLoss: Boolean) extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - type PartitionOffsetMap = Map[TopicPartition, Long] - private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -76,6 +73,7 @@ private[kafka010] class KafkaMicroBatchReader( private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + private val rangeCalculator = KafkaOffsetRangeCalculator(options) /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -106,15 +104,15 @@ private[kafka010] class KafkaMicroBatchReader( override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) - val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) - if (newPartitionOffsets.keySet != newPartitions) { + val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionInitialOffsets.keySet != newPartitions) { // We cannot get from offsets for some partitions. It means they got deleted. - val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) reportDataLoss( s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") } - logInfo(s"Partitions added: $newPartitionOffsets") - newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + logInfo(s"Partitions added: $newPartitionInitialOffsets") + newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => reportDataLoss( s"Added partition $p starts from $o instead of 0. Some data may have been missed") } @@ -125,46 +123,28 @@ private[kafka010] class KafkaMicroBatchReader( reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") } - // Use the until partitions to calculate offset ranges to ignore partitions that have + // Use the end partitions to calculate offset ranges to ignore partitions that have // been deleted val topicPartitions = endPartitionOffsets.keySet.filter { tp => // Ignore partitions that we don't know the from offsets. - newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) }.toSeq logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) - val sortedExecutors = getSortedExecutorList() - val numExecutors = sortedExecutors.length - logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) - // Calculate offset ranges - val factories = topicPartitions.flatMap { tp => - val fromOffset = startPartitionOffsets.get(tp).getOrElse { - newPartitionOffsets.getOrElse( - tp, { - // This should not happen since newPartitionOffsets contains all partitions not in - // fromPartitionOffsets - throw new IllegalStateException(s"$tp doesn't have a from offset") - }) - } - val untilOffset = endPartitionOffsets(tp) - - if (untilOffset >= fromOffset) { - // This allows cached KafkaConsumers in the executors to be re-used to read the same - // partition in every batch. - val preferredLoc = if (numExecutors > 0) { - Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) - } else None - val range = KafkaOffsetRange(tp, fromOffset, untilOffset) - Some( - new KafkaMicroBatchDataReaderFactory( - range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) - } else { - reportDataLoss( - s"Partition $tp's offset was changed from " + - s"$fromOffset to $untilOffset, some data may have been missed") - None - } + val offsetRanges = rangeCalculator.getRanges( + fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, + untilOffsets = endPartitionOffsets, + executorLocations = getSortedExecutorList()) + + // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, + // that is, concurrent tasks will not read the same TopicPartitions. + val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size + + // Generate factories based on the offset ranges + val factories = offsetRanges.map { range => + new KafkaMicroBatchDataReaderFactory( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava } @@ -320,28 +300,39 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReaderFactory( - range: KafkaOffsetRange, - preferredLoc: Option[String], +private[kafka010] case class KafkaMicroBatchDataReaderFactory( + offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { - override def preferredLocations(): Array[String] = preferredLoc.toArray + override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } /** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReader( +private[kafka010] case class KafkaMicroBatchDataReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = { + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We + // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + } - private val consumer = CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -369,9 +360,14 @@ private[kafka010] class KafkaMicroBatchDataReader( } override def close(): Unit = { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + if (!reuseKafkaConsumer) { + // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! + consumer.close() + } else { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { @@ -392,12 +388,9 @@ private[kafka010] class KafkaMicroBatchDataReader( } else { range.untilOffset } - KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None) } else { range } } } - -private[kafka010] case class KafkaOffsetRange( - topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala new file mode 100644 index 0000000000000..6631ae84167c8 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.sources.v2.DataSourceOptions + + +/** + * Class to calculate offset ranges to process based on the the from and until offsets, and + * the configured `minPartitions`. + */ +private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { + require(minPartitions.isEmpty || minPartitions.get > 0) + + import KafkaOffsetRangeCalculator._ + /** + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` + * is not set or is set less than or equal the number of `topicPartitions` that we're going to + * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If + * `numPartitions` is set higher than the number of our `topicPartitions`, then we will split up + * the read tasks of the skewed partitions to multiple Spark tasks. + * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more + * depending on rounding errors or Kafka partitions that didn't receive any new data. + */ + def getRanges( + fromOffsets: PartitionOffsetMap, + untilOffsets: PartitionOffsetMap, + executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = { + val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet) + + val offsetRanges = partitionsToRead.toSeq.map { tp => + KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None) + }.filter(_.size > 0) + + // If minPartitions not set or there are enough partitions to satisfy minPartitions + if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) { + // Assign preferred executor locations to each range such that the same topic-partition is + // preferentially read from the same executor and the KafkaConsumer can be reused. + offsetRanges.map { range => + range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) + } + } else { + + // Splits offset ranges with relatively large amount of data to smaller ones. + val totalSize = offsetRanges.map(_.size).sum + val idealRangeSize = totalSize.toDouble / minPartitions.get + + offsetRanges.flatMap { range => + // Split the current range into subranges as close to the ideal range size + val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt + + (0 until numSplitsInRange).map { i => + val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange) + val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange) + KafkaOffsetRange( + range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None) + } + } + } + } + + private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = { + def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b + + val numExecutors = executorLocations.length + if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + Some(executorLocations(floorMod(tp.hashCode, numExecutors))) + } else None + } +} + +private[kafka010] object KafkaOffsetRangeCalculator { + + def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + new KafkaOffsetRangeCalculator(optionalValue) + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + lazy val size: Long = untilOffset - fromOffset +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 0aa64a6a9cf90..36b9f0466566b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -348,6 +348,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister throw new IllegalArgumentException("Unknown option") } + // Validate minPartitions value if present + if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) { + val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt + if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") + } + // Validate user-specified Kafka options if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -455,6 +461,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + private val MIN_PARTITIONS_OPTION_KEY = "minpartitions" val TOPIC_OPTION_KEY = "topic" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala new file mode 100644 index 0000000000000..43acd6a8d9473 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.kafka.common.TopicPartition + +package object kafka010 { // scalastyle:ignore + // ^^ scalastyle:ignore is for ignoring warnings about digits in package name + type PartitionOffsetMap = Map[TopicPartition, Long] +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 89c9ef4cc73b5..f2b3ff7615e74 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Properties} +import java.util.{Locale, Optional, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.io.Source import scala.util.Random @@ -34,15 +35,19 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -642,6 +647,53 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { } ) } + + testWithUninterruptibleThread("minPartitions is supported") { + import testImplicits._ + + val topic = newTopic() + val tp = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + def test( + minPartitions: String, + numPartitionsGenerated: Int, + reusesConsumers: Boolean): Unit = { + + SparkSession.setActiveSession(spark) + withTempDir { dir => + val provider = new KafkaSourceProvider() + val options = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "subscribe" -> topic + ) ++ Option(minPartitions).map { p => "minPartitions" -> p} + val reader = provider.createMicroBatchReader( + Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + reader.setOffsetRange( + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) + ) + val factories = reader.createUnsafeRowReaderFactories().asScala + .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { + assert(factories.size == numPartitionsGenerated) + factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + } + } + } + + // Test cases when minPartitions is used and not used + test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false) + + // Test illegal minPartitions values + intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } + } + } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala new file mode 100644 index 0000000000000..2ccf3e291bea7 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.sources.v2.DataSourceOptions + +class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { + + def testWithMinPartitions(name: String, minPartition: Int) + (f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + test(s"with minPartition = $minPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + + test("with no minPartition: N TopicPartitions to N offset ranges") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2), + executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 2, Some("location")))) + } + + test("with no minPartition: empty ranges ignored") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + } + + testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp2, 1, 2, None), + KafkaOffsetRange(tp3, 1, 2, None))) + } + + testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5), + executorLocations = Seq("location")) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set + } + + testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + testWithMinPartitions("range inexact multiple of minPartitions", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 11)) == + Seq( + KafkaOffsetRange(tp1, 1, 4, None), + KafkaOffsetRange(tp1, 4, 7, None), + KafkaOffsetRange(tp1, 7, 11, None))) + } + + testWithMinPartitions("empty ranges ignored", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + private val tp1 = new TopicPartition("t1", 1) + private val tp2 = new TopicPartition("t2", 1) + private val tp3 = new TopicPartition("t3", 1) +} From a89cdf55fa76fa23a524f0443e323498c3cc8664 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 5 Mar 2018 07:32:24 +0900 Subject: [PATCH 111/593] [SQL][MINOR] XPathDouble prettyPrint should say 'double' not 'float' ## What changes were proposed in this pull request? It looks like this was incorrectly copied from `XPathFloat` in the class above. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Liang Closes #20730 from ericl/fix-typo-xpath. --- .../org/apache/spark/sql/catalyst/expressions/xml/xpath.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index d0185562c9cfc..aacf1a44e2ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -160,7 +160,7 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { """) // scalastyle:on line.size.limit case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { - override def prettyName: String = "xpath_float" + override def prettyName: String = "xpath_double" override def dataType: DataType = DoubleType override def nullSafeEval(xml: Any, path: Any): Any = { From 7965c91d8a67c213ca5eebda5e46e7c49a8ba121 Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 5 Mar 2018 13:36:42 +0900 Subject: [PATCH 112/593] [SPARK-23569][PYTHON] Allow pandas_udf to work with python3 style type-annotated functions ## What changes were proposed in this pull request? Check python version to determine whether to use `inspect.getargspec` or `inspect.getfullargspec` before applying `pandas_udf` core logic to a function. The former is python2.7 (deprecated in python3) and the latter is python3.x. The latter correctly accounts for type annotations, which are syntax errors in python2.x. ## How was this patch tested? Locally, on python 2.7 and 3.6. Author: Michael (Stu) Stewart Closes #20728 from mstewart141/pandas_udf_fix. --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ python/pyspark/sql/udf.py | 9 ++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 19653072ea316..fa3b7203e10ac 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4381,6 +4381,24 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") + def test_type_annotation(self): + from pyspark.sql.functions import pandas_udf + # Regression test to check if type hints can be used. See SPARK-23569. + # Note that it throws an error during compilation in lower Python versions if 'exec' + # is not used. Also, note that we explicitly use another dictionary to avoid modifications + # in the current 'locals()'. + # + # Hyukjin: I think it's an ugly way to test issues about syntax specific in + # higher versions of Python, which we shouldn't encourage. This was the last resort + # I could come up with at that time. + _locals = {} + exec( + "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", + _locals) + df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + self.assertEqual(df.first()[0], 0) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index e5b35fc60e167..b9b490874f4fb 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -42,10 +42,17 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect + import sys from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - argspec = inspect.getargspec(f) + + if sys.version_info[0] < 3: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: From 269cd53590dd155aeb5269efc909a6e228f21e22 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Mar 2018 21:22:30 -0800 Subject: [PATCH 113/593] [MINOR][DOCS] Fix a link in "Compatibility with Apache Hive" ## What changes were proposed in this pull request? This PR fixes a broken link as below: **Before:** 2018-03-05 12 23 58 **After:** 2018-03-05 12 23 20 Also see https://spark.apache.org/docs/2.3.0/sql-programming-guide.html#compatibility-with-apache-hive ## How was this patch tested? Manually tested. I checked the same instances in `docs` directory. Seems this is the only one. Author: hyukjinkwon Closes #20733 from HyukjinKwon/minor-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c37c338a134f3..4d0f015f401bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2223,7 +2223,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 2ce37b50fc01558f49ad22f89c8659f50544ffec Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 5 Mar 2018 11:39:01 +0100 Subject: [PATCH 114/593] [SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext ## What changes were proposed in this pull request? A current `CodegenContext` class has immutable value or method without mutable state, too. This refactoring moves them to `CodeGenerator` object class which can be accessed from anywhere without an instantiated `CodegenContext` in the program. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20700 from kiszk/SPARK-23546. --- .../catalyst/expressions/BoundAttribute.scala | 9 +- .../spark/sql/catalyst/expressions/Cast.scala | 35 +- .../sql/catalyst/expressions/Expression.scala | 16 +- .../MonotonicallyIncreasingID.scala | 8 +- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../expressions/SparkPartitionID.scala | 7 +- .../sql/catalyst/expressions/TimeWindow.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 51 +- .../expressions/bitwiseExpressions.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 458 +++++++++--------- .../expressions/codegen/CodegenFallback.scala | 7 +- .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateOrdering.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 6 +- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/collectionOperations.scala | 6 +- .../expressions/complexTypeCreator.scala | 4 +- .../expressions/complexTypeExtractors.scala | 15 +- .../expressions/conditionalExpressions.scala | 10 +- .../expressions/datetimeExpressions.scala | 18 +- .../spark/sql/catalyst/expressions/hash.scala | 25 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 8 +- .../expressions/mathExpressions.scala | 5 +- .../expressions/nullExpressions.scala | 22 +- .../expressions/objects/objects.scala | 99 ++-- .../sql/catalyst/expressions/predicates.scala | 14 +- .../expressions/randomExpressions.scala | 8 +- .../expressions/regexpExpressions.scala | 8 +- .../expressions/stringExpressions.scala | 39 +- .../expressions/CodeGenerationSuite.scala | 4 +- .../sql/execution/ColumnarBatchScan.scala | 13 +- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 8 +- .../apache/spark/sql/execution/SortExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 16 +- .../aggregate/HashMapGenerator.scala | 8 +- .../aggregate/RowBasedHashMapGenerator.scala | 8 +- .../VectorizedHashMapGenerator.scala | 11 +- .../execution/basicPhysicalOperators.scala | 10 +- .../columnar/GenerateColumnAccessor.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 5 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- .../apache/spark/sql/execution/limit.scala | 7 +- 45 files changed, 535 insertions(+), 497 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef2..89ffbb0016916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = s""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 79b051670e9e4..12330bfa55ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultIsNull = $inputIsNull; - ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val funcName = ctx.freshName("elementToString") val elementToStringFunc = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(et)} element) { + |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { | UTF8String elementStr = null; | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} | return elementStr; @@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); | } | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { | $buffer.append(","); | if (!$array.isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); | } | } |} @@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val dataToStringCode = castToStringCode(dataType, ctx) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { | UTF8String dataStr = null; | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} | return dataStr; @@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) val loopIndex = ctx.freshName("loopIndex") + val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") + val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") + val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) s""" |$buffer.append("["); |if ($map.numElements() > 0) { - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append($keyToStringFunc($getMapFirstKey)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt(0)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | $buffer.append($valueToStringFunc($getMapFirstValue)); | } | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { | $buffer.append(", "); - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append($keyToStringFunc($getMapKeyArray)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc( - | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | $buffer.append($valueToStringFunc($getMapValueArray)); | } | } |} @@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | ${if (i != 0) s"""$buffer.append(" ");""" else ""} | | // Append $i field into the string buffer - | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; + ${CodeGenerator.javaType(fromType)} $fromElementPrim = + ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { @@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") val toFieldNull = ctx.freshName("tfn") - val fromType = ctx.javaType(from.fields(i).dataType) + val fromType = CodeGenerator.javaType(from.fields(i).dataType) s""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; + ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4568714933095..ed90b185865a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") + val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull s"$globalIsNull = $localIsNull;" @@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] { "" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) @@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression { ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { @@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression { boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = s""" @@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression { ${leftGen.code} ${midGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 11fb579dfa88c..4523079060896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 989c02305620a..e869258469a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1018,11 +1018,12 @@ case class ScalaUDF( val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" val resultConverter = s"$convertersTerm[${children.length}]" + val boxedType = CodeGenerator.boxedType(dataType) val callFunc = s""" - |${ctx.boxedType(dataType)} $resultTerm = null; + |$boxedType $resultTerm = null; |try { - | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); + | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); |} catch (Exception e) { | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} @@ -1035,7 +1036,7 @@ case class ScalaUDF( |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index a160b9b275290..cc6a769d032d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9a9f579b37f58..6c4a3601c1730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -165,7 +165,7 @@ case class PreciseTimestampConversion( val eval = child.genCode(ctx) ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; - |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6d7b..508bdd5050b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code s""" - ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); + ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); + ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } @@ -107,7 +107,7 @@ case class Abs(child: Expression) case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => @@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => @@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s"${eval2.value} == 0" } val remainder = ctx.freshName("remainder") - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val result = dataType match { case DecimalType.Fixed(_, _) => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value}); + $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); } else { @@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $remainder = - (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value}); + $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { - ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value}); + ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); } else { ${ev.value}=$remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value}; + $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; } else { @@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", @@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } @@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", @@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 173481f06a716..cc24e397cc14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60a6f50472504..793824b0b0a2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: String, var value: String) object ExprCode { + def forNullValue(dataType: DataType): ExprCode = { + val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + } + def forNonNullValue(value: String): ExprCode = { ExprCode(code = "", isNull = "false", value = value) } @@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec( */ class CodegenContext { + import CodeGenerator._ + /** * Holding a list of objects that could be used passed into generated class. */ @@ -196,11 +203,11 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { - if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) { val res = s"${arrayNames.last}[$currentIndex]" currentIndex += 1 res @@ -247,10 +254,10 @@ class CodegenContext { * are satisfied: * 1. forceInline is true * 2. its type is primitive type and the total number of the inlined mutable variables - * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * is less than `OUTER_CLASS_VARIABLES_THRESHOLD` * 3. its type is multi-dimensional array * When a variable is compacted into an array, the max size of the array for compaction - * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, @@ -261,7 +268,7 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) @@ -339,7 +346,7 @@ class CodegenContext { val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { - CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + MUTABLESTATEARRAY_SIZE_LIMIT } if (javaType.contains("[]")) { // initializer had an one-dimensional array variable @@ -468,7 +475,7 @@ class CodegenContext { inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { + } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -537,14 +544,6 @@ class CodegenContext { extraClasses.append(code) } - final val JAVA_BOOLEAN = "boolean" - final val JAVA_BYTE = "byte" - final val JAVA_SHORT = "short" - final val JAVA_INT = "int" - final val JAVA_LONG = "long" - final val JAVA_FLOAT = "float" - final val JAVA_DOUBLE = "double" - /** * The map from a variable name to it's next ID. */ @@ -580,196 +579,6 @@ class CodegenContext { } } - /** - * Returns the specialized code to access a value from `inputRow` at `ordinal`. - */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" - } - } - - /** - * Returns the code to update a column in Row for a given DataType. - */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) - // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy - // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" - } - } - - /** - * Update a column in MutableRow from ExprCode. - * - * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise - */ - def updateColumn( - row: String, - dataType: DataType, - ordinal: Int, - ev: ExprCode, - nullable: Boolean, - isVectorized: Boolean = false): String = { - if (nullable) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - ${setColumn(row, dataType, ordinal, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - $row.setNullAt($ordinal); - } - """ - } - } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType`. - */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType` - * that could potentially be nullable. - */ - def updateColumn( - vector: String, - rowId: String, - dataType: DataType, - ev: ExprCode, - nullable: Boolean): String = { - if (nullable) { - s""" - if (!${ev.isNull}) { - ${setValue(vector, rowId, dataType, ev.value)} - } else { - $vector.putNull($rowId); - } - """ - } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" - } - } - - /** - * Returns the specialized code to access a value from a column vector for a given `DataType`. - */ - def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { - if (dataType.isInstanceOf[StructType]) { - // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an - // `ordinal` parameter. - s"$vector.getStruct($rowId)" - } else { - getValue(vector, dataType, rowId) - } - } - - /** - * Returns the name used in accessor and setter for a Java primitive type. - */ - def primitiveTypeName(jt: String): String = jt match { - case JAVA_INT => "Int" - case _ => boxedType(jt) - } - - def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) - - /** - * Returns the Java type for a DataType. - */ - def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case dt: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" - case ObjectType(cls) => cls.getName - case _ => "Object" - } - - /** - * Returns the boxed type in Java. - */ - def boxedType(jt: String): String = jt match { - case JAVA_BOOLEAN => "Boolean" - case JAVA_BYTE => "Byte" - case JAVA_SHORT => "Short" - case JAVA_INT => "Integer" - case JAVA_LONG => "Long" - case JAVA_FLOAT => "Float" - case JAVA_DOUBLE => "Double" - case other => other - } - - def boxedType(dt: DataType): String = boxedType(javaType(dt)) - - /** - * Returns the representation of default value for a given Java Type. - */ - def defaultValue(jt: String): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => "null" - } - - def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - /** * Generates code for equal expression in Java. */ @@ -812,6 +621,7 @@ class CodegenContext { val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") + val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -833,8 +643,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; - ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue("a", elementType, "i")}; + $jt $elementB = ${getValue("b", elementType, "i")}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -906,19 +716,6 @@ class CodegenContext { } } - /** - * List of java data types that have special accessors and setters in [[InternalRow]]. - */ - val primitiveTypes = - Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - - /** - * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. - */ - def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - - def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) - /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1089,7 +886,7 @@ class CodegenContext { // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse - if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) { // Adding a new function to each inner class which contains the invocation of all the // ones which have been added to that inner class. For example, // private class NestedClass { @@ -1289,7 +1086,7 @@ class CodegenContext { * length less than a pre-defined constant. */ def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH } } @@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging { result } }) + + /** + * Name of Java primitive data type + */ + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + + /** + * List of java primitive data types + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) + + /** + * Returns true if a Java type is Java primitive primitive type + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Returns the specialized code to access a value from `inputRow` at `ordinal`. + */ + def getValue(input: String, dataType: DataType, ordinal: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" + case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" + } + } + + /** + * Returns the code to update a column in Row for a given DataType. + */ + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" + case _ => s"$row.update($ordinal, $value)" + } + } + + /** + * Update a column in MutableRow from ExprCode. + * + * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean, + isVectorized: Boolean = false): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (!isVectorized && dataType.isInstanceOf[DecimalType]) { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | ${setColumn(row, dataType, ordinal, "null")}; + |} + """.stripMargin + } else { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | $row.setNullAt($ordinal); + |} + """.stripMargin + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType`. + */ + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _ => + throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType` + * that could potentially be nullable. + */ + def updateColumn( + vector: String, + rowId: String, + dataType: DataType, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + s""" + |if (!${ev.isNull}) { + | ${setValue(vector, rowId, dataType, ev.value)} + |} else { + | $vector.putNull($rowId); + |} + """.stripMargin + } else { + s"""${setValue(vector, rowId, dataType, ev.value)};""" + } + } + + /** + * Returns the specialized code to access a value from a column vector for a given `DataType`. + */ + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) + } + } + + /** + * Returns the name used in accessor and setter for a Java primitive type. + */ + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } + + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) + + /** + * Returns the Java type for a DataType. + */ + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case _: DecimalType => "Decimal" + case BinaryType => "byte[]" + case StringType => "UTF8String" + case CalendarIntervalType => "CalendarInterval" + case _: StructType => "InternalRow" + case _: ArrayType => "ArrayData" + case _: MapType => "MapData" + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName + case _ => "Object" + } + + /** + * Returns the boxed type in Java. + */ + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + + /** + * Returns the representation of default value for a given Java Type. + * @param jt the string name of the Java type + * @param typedNull if true, for null literals, return a typed (with a cast) version + */ + def defaultValue(jt: String, typedNull: Boolean): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + + def defaultValue(dt: DataType, typedNull: Boolean = false): String = + defaultValue(javaType(dt), typedNull) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a9ff..e12420bb5dfdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -44,20 +44,21 @@ trait CodegenFallback extends Expression { } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) + val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; """, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7e2d..d35fd8ecb4d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) - val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; @@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => val ev = ExprCode("", isNull, value) - ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) + CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed634..9a51be6ed5aeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR s""" ${ctx.INPUT_ROW} = a; boolean $isNullA; - ${ctx.javaType(order.child.dataType)} $primitiveA; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; @@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } ${ctx.INPUT_ROW} = b; boolean $isNullB; - ${ctx.javaType(order.child.dataType)} $primitiveB; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba42a..f92f70ee71fef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] mutableRow.setNullAt($i); } else { ${converter.code} - ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b6..22717f5954a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } s""" @@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case other => other } - val jt = ctx.javaType(et) + val jt = CodeGenerator.javaType(et) val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize case _ => 8 // we need 8 bytes to store offset and length } val tmpCursor = ctx.freshName("tmpCursor") - val element = ctx.getValue(tmpInput, et, index) + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val primitiveTypeName = + if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..beb84694c44e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = "false") } } @@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 047b80ac5289c..85facdad43db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -90,7 +90,7 @@ private [sql] object GenArrayData { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length - if (!ctx.isPrimitiveType(elementType)) { + if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName @@ -124,7 +124,7 @@ private [sql] object GenArrayData { ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca3908905..6cdad19168dce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; } """ } else { s""" - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } }) @@ -205,7 +205,7 @@ case class GetArrayStructFields( } else { final InternalRow $row = $eval.getStruct($j, $numFields); $nullSafeEval { - $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; } } } @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ }) @@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression) } else { "" } + val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression) int $index = 0; boolean $found = false; while ($index < $length && !$found) { - final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)}; if (${ctx.genEqual(keyType, key, eval2)}) { $found = true; } else { @@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression) if (!$found$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(values, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b444c3a7be92a..f4e9619bac59d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" |${condEval.code} |boolean ${ev.isNull} = false; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -244,10 +244,10 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BYTE, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -264,7 +264,7 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 424871f2047e9..1ae4e5a2f716b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -673,18 +673,19 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode.forNullValue(dataType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; @@ -713,7 +714,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; }""") @@ -724,7 +725,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") @@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.$truncFuncStr; }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da54..b702422ed7a1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression { } } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", @@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression { ctx: CodegenContext): String = { val element = ctx.freshName("element") + val jt = CodeGenerator.javaType(elementType) ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" - final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } @@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression { val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", @@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, + extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$body |return ${ev.value}; """.stripMargin, @@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = s""" - |${ctx.JAVA_INT} ${ev.value} = $seed; - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes """.stripMargin) } @@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), - returnType = ctx.JAVA_INT, + arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childResult = 0; + |${CodeGenerator.JAVA_INT} $childResult = 0; |$body |return $result; """.stripMargin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed1757..07785e7448586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getInputFilePath();", isNull = "false") } } @@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getStartOffset();", isNull = "false") } } @@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getLength();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index c1e65e34c2ea6..7395609a04ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) if (value == null) { - val defaultValueLiteral = ctx.defaultValue(javaType) match { - case "null" => s"(($javaType)null)" - case lit => lit - } - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode.forNullValue(dataType) } else { dataType match { case BooleanType | IntegerType | DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d8dc0862f1141..2c2cf3d2e6227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression, }""" } + val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 470d5da041ea5..b35fa72e95d1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", returnType = resultType, makeSplitFunction = func => s""" - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $codes |} while (false); @@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression case DoubleType | FloatType => ev.copy(code = s""" ${eval.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") } } @@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression) ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, + extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" |do { @@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate ev.copy(code = s""" - |${ctx.JAVA_INT} $nonnull = 0; + |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes |} while (false); - |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64da9bb9cdec1..80618af1e859f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") argValue } @@ -137,7 +137,7 @@ case class StaticInvoke( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -151,7 +151,7 @@ case class StaticInvoke( } val evaluate = if (returnNullable) { - if (ctx.defaultValue(dataType) == "null") { + if (CodeGenerator.defaultValue(dataType) == "null") { s""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; @@ -159,7 +159,7 @@ case class StaticInvoke( } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -173,7 +173,7 @@ case class StaticInvoke( val code = s""" $argCode $prepareIsNull - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!$resultIsNull) { $evaluate } @@ -228,7 +228,7 @@ case class Invoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -255,11 +255,11 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" } else { s""" if ($funcResult != null) { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; } else { ${ev.isNull} = true; } @@ -275,7 +275,7 @@ case class Invoke( val code = s""" ${obj.code} boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { $argCode ${ev.isNull} = $resultIsNull; @@ -341,7 +341,7 @@ case class NewInstance( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -358,7 +358,8 @@ case class NewInstance( val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ ev.copy(code = code) } @@ -385,15 +386,15 @@ case class UnwrapOption( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) val code = s""" ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -546,7 +547,7 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVarDataType) + val elementJavaType = CodeGenerator.javaType(loopVarDataType) ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -554,7 +555,7 @@ case class MapObjects private( val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(lambdaFunction.dataType) + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -621,7 +622,7 @@ case class MapObjects private( ( s"${genInputData.value}.numElements()", "", - ctx.getValue(genInputData.value, et, loopIndex) + CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => val it = ctx.freshName("it") @@ -643,7 +644,8 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -695,7 +697,7 @@ case class MapObjects private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -806,10 +808,10 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = ctx.javaType(mapType.keyType) + val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(mapType.valueType) + val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) @@ -825,10 +827,11 @@ case class CatalystToExternalMap private( val valueArray = ctx.freshName("valueArray") val getKeyArray = s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" - val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" - val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + val getValueLoopVar = CodeGenerator.getValue( + valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -844,7 +847,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { @@ -873,7 +876,7 @@ case class CatalystToExternalMap private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = ctx.javaType(keyType) - val valueElementJavaType = ctx.javaType(valueType) + val keyElementJavaType = CodeGenerator.javaType(keyType) + val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) @@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry._1(); - $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private( val arrayCls = classOf[GenericArrayData].getName val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = ctx.boxedType(keyConverter.dataType) - val convertedValueType = ctx.boxedType(valueConverter.dataType) + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) val code = s""" ${inputMap.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // Code to deserialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val instanceGen = beanInstance.genCode(ctx) val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) val initialize = setters.map { case (setterMethod, fieldValue) => @@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: ArrayType => s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => - s"$obj instanceof ${ctx.boxedType(dataType)}" + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } val code = s""" ${input.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a6d41ea7d00d4..4b85d9adbe311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = ctx.javaType(value.dataType) + val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, - returnType = ctx.JAVA_BYTE, + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" |do { @@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = s""" |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (ctx.isPrimitiveType(left.dataType) + if (CodeGenerator.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbfc31..6c9937dacc70b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = "false") } } @@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f3e8f6de58975..ad0c0791d895f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4c57..22fbb8998ed89 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) ev.copy(s""" $initCode $codes - ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } @@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression]) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") - val idxInVararg = ctx.freshName("idxInVararg") + val idxVararg = ctx.freshName("idxInVararg") val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => @@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression]) if (eval.isNull == "true") { "" } else { - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") @@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression]) if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")}; } } """) @@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression]) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body - |return $idxInVararg; + |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( s""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; - int $idxInVararg = 0; + int $idxVararg = 0; $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; $varargBuilds @@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression { val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") + val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" @@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression { expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, - returnType = ctx.JAVA_BOOLEAN, + returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |do { | $body |} while (false); @@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression { s""" |${index.code} |final int $indexVal = ${index.value}; - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |$inputVal = null; |do { | $codes |} while (false); - |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; + |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => val value = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.value})" + s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + + s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } @@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); @@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression) val usLocale = "US" val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956ddc8..1e48c7b8df9da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { - ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;") } assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1) assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04f2619ed7541..392906a022903 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValueFromVector(columnVar, dataType, ordinal) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value); + $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { s"$javaType $valueVar = $value;" @@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 + val scanTimeTotalNs = + ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 + val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93ecd..12ae1ea4a7c13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -154,7 +154,8 @@ case class ExpandExec( val value = ctx.freshName("value") val code = s""" |boolean $isNull = true; - |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + |${CodeGenerator.javaType(firstExpr.dataType)} $value = + | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin ExprCode(code, isNull, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0c2c4a1a9100d..384f0398a1ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -305,15 +305,15 @@ case class GenerateExec( nullable: Boolean, initialChecks: Seq[String]): ExprCode = { val value = ctx.freshName(name) - val javaType = ctx.javaType(dt) - val getter = ctx.getValue(source, dt, index) + val javaType = CodeGenerator.javaType(dt) + val getter = CodeGenerator.getValue(source, dt, index) val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = s""" |boolean $isNull = ${checks.mkString(" || ")}; - |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, isNull, value) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ac1c34d41c4f1..0dc16ba5ce281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -133,7 +133,8 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index deb0a044c2fb2..f89e3fb0e536f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan { variables.zipWithIndex.foreach { case (ev, i) => val paramName = ctx.freshName(s"expr_$i") - val paramType = ctx.javaType(attributes(i).dataType) + val paramType = CodeGenerator.javaType(attributes(i).dataType) arguments += ev.value parameters += s"$paramType $paramName" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ce3c68810f3b6..1926e9373bc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -186,8 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -532,7 +532,7 @@ case class HashAggregateExec( */ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { val isSupported = - (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) || + (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) @@ -565,7 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -757,7 +757,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { @@ -832,7 +832,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -855,7 +855,7 @@ case class HashAggregateExec( } val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( + CodeGenerator.updateColumn( fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4ab1..6b60b414ffe5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -41,13 +41,13 @@ abstract class HashMapGenerator( val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ") val buffVars: Seq[ExprCode] = { val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index fd25707dd4ca6..8617be88f3570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.types._ /** @@ -114,7 +114,7 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", key.dataType, ordinal.toString()), key.name)})""" }.mkString(" && ") } @@ -147,7 +147,7 @@ class RowBasedHashMapGenerator( case t: DecimalType => s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})" case t: DataType => - if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) { + if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) { throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t") } s"agg_rowWriter.write(${ordinal}, ${key.name})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 633eeac180974..7b3580cecc60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -127,7 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, + "buckets[idx]") s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } @@ -182,14 +183,14 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, - buffVars(ordinal), nullable = true) + CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", + key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a15a8d11aa2a0..4707022f74547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number") + val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") + val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") + val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4f28eeb725cbb..3b5655ba0582e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { - case t if ctx.isPrimitiveType(dt) => + case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 1918fcc5482db..487d6a2383318 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -182,9 +182,10 @@ case class BroadcastHashJoinExec( // the variables are needed even there is no matched rows val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) val code = s""" |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { | ${ev.code} | $isNull = ${ev.isNull}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb05d3..5a511b30e4fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -516,9 +516,9 @@ case class SortMergeJoinExec( ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - val javaType = ctx.javaType(a.dataType) - val defaultValue = ctx.defaultValue(a.dataType) + val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val javaType = CodeGenerator.javaType(a.dataType) + val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cccee63bc0680..66bcda8913738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils @@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false + val stopEarly = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; From 42cf48e20cd5e47e1b7557af9c71c4eea142f10f Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Mon, 5 Mar 2018 14:33:12 +0100 Subject: [PATCH 115/593] [SPARK-23496][CORE] Locality of coalesced partitions can be severely skewed by the order of input partitions ## What changes were proposed in this pull request? The algorithm in `DefaultPartitionCoalescer.setupGroups` is responsible for picking preferred locations for coalesced partitions. It analyzes the preferred locations of input partitions. It starts by trying to create one partition for each unique location in the input. However, if the the requested number of coalesced partitions is higher that the number of unique locations, it has to pick duplicate locations. Previously, the duplicate locations would be picked by iterating over the input partitions in order, and copying their preferred locations to coalesced partitions. If the input partitions were clustered by location, this could result in severe skew. With the fix, instead of iterating over the list of input partitions in order, we pick them at random. It's not perfectly balanced, but it's much better. ## How was this patch tested? Unit test reproducing the behavior was added. Author: Ala Luszczak Closes #20664 from ala/SPARK-23496. --- .../org/apache/spark/rdd/CoalescedRDD.scala | 8 ++-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 10451a324b0f4..94e7d0b38cba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -266,17 +266,17 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) numCreated += 1 } } - tries = 0 // if we don't have enough partition groups, create duplicates while (numCreated < targetLen) { - val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) - tries += 1 + // Copy the preferred location from a random input partition. + // This helps in avoiding skew when the input partitions are clustered by preferred location. + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs( + rnd.nextInt(partitionLocs.partsWithLocs.length)) val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup addPartToPGroup(nxt_part, pgroup) numCreated += 1 - if (tries >= partitionLocs.partsWithLocs.length) tries = 0 } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e994d724c462f..191c61250ce21 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -1129,6 +1129,35 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { }.collect() } + test("SPARK-23496: order of input partitions can result in severe skew in coalesce") { + val numInputPartitions = 100 + val numCoalescedPartitions = 50 + val locations = Array("locA", "locB") + + val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions) + assert(inputRDD.getNumPartitions == numInputPartitions) + + val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => + if (p.index < numCoalescedPartitions) { + Seq(locations(0)) + } else { + Seq(locations(1)) + } + }) + val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions) + + val numPartsPerLocation = coalescedRDD + .getPartitions + .map(coalescedRDD.getPreferredLocations(_).head) + .groupBy(identity) + .mapValues(_.size) + + // Make sure the coalesced partitions are distributed fairly evenly between the two locations. + // This should not become flaky since the DefaultPartitionsCoalescer uses a fixed seed. + assert(numPartsPerLocation(locations(0)) > 0.4 * numCoalescedPartitions) + assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions) + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because @@ -1210,3 +1239,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria groups.toArray } } + +/** Alters the preferred locations of the parent RDD using provided function. */ +class LocationPrefRDD[T: ClassTag]( + @transient var prev: RDD[T], + val locationPicker: Partition => Seq[String]) extends RDD[T](prev) { + override protected def getPartitions: Array[Partition] = prev.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = + null.asInstanceOf[Iterator[T]] + + override def getPreferredLocations(partition: Partition): Seq[String] = + locationPicker(partition) +} From 5ff72ffcf495d2823f7f1186078d1cb261667c3d Mon Sep 17 00:00:00 2001 From: Anirudh Date: Mon, 5 Mar 2018 23:17:16 +0900 Subject: [PATCH 116/593] [SPARK-23566][MINOR][DOC] Argument name mismatch fixed Argument name mismatch fixed. ## What changes were proposed in this pull request? `col` changed to `new` in doc string to match the argument list. Patch file added: https://issues.apache.org/jira/browse/SPARK-23566 Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Anirudh Closes #20716 from animenon/master. --- python/pyspark/sql/dataframe.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f37777e13ee12..9d8e85cde914f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -588,6 +588,8 @@ def coalesce(self, numPartitions): """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + :param numPartitions: int, to specify the target number of partitions + Similar to coalesce defined on an :class:`RDD`, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will @@ -612,9 +614,10 @@ def repartition(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. .. versionchanged:: 1.6 Added optional arguments to specify the partitioning columns. Also made numPartitions @@ -673,9 +676,10 @@ def repartitionByRange(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is range partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. @@ -892,6 +896,8 @@ def colRegex(self, colName): def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. + :param alias: string, an alias name to be set for the DataFrame. + >>> from pyspark.sql.functions import * >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") @@ -1900,7 +1906,7 @@ def withColumnRenamed(self, existing, new): This is a no-op if schema doesn't contain the given column name. :param existing: string, name of the existing column to rename. - :param col: string, new name of the column. + :param new: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] From a366b950b90650693ad0eb1e5b9a988ad028d845 Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Mon, 5 Mar 2018 23:46:40 +0900 Subject: [PATCH 117/593] [SPARK-23329][SQL] Fix documentation of trigonometric functions ## What changes were proposed in this pull request? Provide more details in trigonometric function documentations. Referenced `java.lang.Math` for further details in the descriptions. ## How was this patch tested? Ran full build, checked generated documentation manually Author: Mihaly Toth Closes #20618 from misutoth/trigonometric-doc. --- R/pkg/R/functions.R | 34 ++-- python/pyspark/sql/functions.py | 62 ++++--- .../expressions/mathExpressions.scala | 99 ++++++++--- .../org/apache/spark/sql/functions.scala | 160 ++++++++++++------ 4 files changed, 248 insertions(+), 107 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f7c6317cd924..29ee146ab14f9 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -278,8 +278,8 @@ setMethod("abs", }) #' @details -#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in -#' the range 0.0 through pi. +#' \code{acos}: Returns the inverse cosine of the given value, +#' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions #' @export @@ -334,8 +334,8 @@ setMethod("ascii", }) #' @details -#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in -#' the range -pi/2 through pi/2. +#' \code{asin}: Returns the inverse sine of the given value, +#' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions #' @export @@ -349,8 +349,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. +#' \code{atan}: Returns the inverse tangent of the given value, +#' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions #' @export @@ -613,7 +613,8 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. Units in radians. +#' \code{cos}: Returns the cosine of the given value, +#' as if computed by \code{java.lang.Math.cos()}. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -627,7 +628,8 @@ setMethod("cos", }) #' @details -#' \code{cosh}: Computes the hyperbolic cosine of the given value. +#' \code{cosh}: Returns the hyperbolic cosine of the given value, +#' as if computed by \code{java.lang.Math.cosh()}. #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method @@ -1463,7 +1465,8 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. Units in radians. +#' \code{sin}: Returns the sine of the given value, +#' as if computed by \code{java.lang.Math.sin()}. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1477,7 +1480,8 @@ setMethod("sin", }) #' @details -#' \code{sinh}: Computes the hyperbolic sine of the given value. +#' \code{sinh}: Returns the hyperbolic sine of the given value, +#' as if computed by \code{java.lang.Math.sinh()}. #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method @@ -1653,7 +1657,9 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. Units in radians. +#' \code{tan}: Returns the tangent of the given value, +#' as if computed by \code{java.lang.Math.tan()}. +#' Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1667,7 +1673,8 @@ setMethod("tan", }) #' @details -#' \code{tanh}: Computes the hyperbolic tangent of the given value. +#' \code{tanh}: Returns the hyperbolic tangent of the given value, +#' as if computed by \code{java.lang.Math.tanh()}. #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method @@ -1973,7 +1980,8 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). Units in radians. +#' (x, y) to polar coordinates (r, theta), +#' as if computed by \code{java.lang.Math.atan2()}. Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bb9c323a5a60..b9c0c57262c5d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -106,18 +106,15 @@ def _(): _functions_1_4 = { # unary math functions - 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + - '0.0 through pi.', - 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2', + 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`', + 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`', + 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': """Computes the cosine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'cos': """:param col: angle in radians + :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""", + 'cosh': """:param col: hyperbolic angle + :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""", 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', 'floor': 'Computes the floor of the given value.', @@ -127,14 +124,16 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': """Computes the sine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': """Computes the tangent of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'sin': """:param col: angle in radians + :return: sine of the angle, as if computed by `java.lang.Math.sin()`""", + 'sinh': """:param col: hyperbolic angle + :return: hyperbolic sine of the given value, + as if computed by `java.lang.Math.sinh()`""", + 'tan': """:param col: angle in radians + :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""", + 'tanh': """:param col: hyperbolic angle + :return: hyperbolic tangent of the given value, + as if computed by `java.lang.Math.tanh()`""", 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', @@ -173,16 +172,31 @@ def _(): _functions_2_1 = { # unary math functions - 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'degrees': """ + Converts an angle measured in radians to an approximately equivalent angle + measured in degrees. + :param col: angle in radians + :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()` + """, + 'radians': """ + Converts an angle measured in degrees to an approximately equivalent angle + measured in radians. + :param col: angle in degrees + :return: angle in radians, as if computed by `java.lang.Math.toRadians()` + """, } # math functions that take two arguments as input _binary_mathfunctions = { - 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta). Units in radians.', + 'atan2': """ + :param col1: coordinate on y-axis + :param col2: coordinate on x-axis + :return: the `theta` component of the point + (`r`, `theta`) + in polar coordinates that corresponds to the point + (`x`, `y`) in Cartesian coordinates, + as if computed by `java.lang.Math.atan2()` + """, 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 2c2cf3d2e6227..bc4cfcec47425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -168,9 +168,11 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -178,12 +180,13 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) the arc sin of `expr`, + as if computed by `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -191,18 +194,18 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).", + usage = """ + _FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by + `java.lang.Math._FUNC_` + """, examples = """ Examples: > SELECT _FUNC_(0); 0.0 """) -// scalastyle:on line.size.limit case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") @ExpressionDescription( @@ -252,7 +255,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -261,7 +271,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -512,7 +529,11 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sine of `expr`.", + usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -521,7 +542,13 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.", + usage = """ + _FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -539,7 +566,13 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH" case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -548,7 +581,13 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT" case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cotangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -562,7 +601,14 @@ case class Cot(child: Expression) } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -572,6 +618,10 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH" @ExpressionDescription( usage = "_FUNC_(expr) - Converts radians to degrees.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(3.141592653589793); @@ -583,6 +633,10 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre @ExpressionDescription( usage = "_FUNC_(expr) - Converts degrees to radians.", + arguments = """ + Arguments: + * expr - angle in degrees + """, examples = """ Examples: > SELECT _FUNC_(180); @@ -768,15 +822,22 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).", + usage = """ + _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane + and the point given by the coordinates (`exprX`, `exprY`), as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * exprY - coordinate on y-axis + * exprX - coordinate on x-axis + """, examples = """ Examples: > SELECT _FUNC_(0, 0); 0.0 """) -// scalastyle:on line.size.limit case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d54c02c3d06f..c9ca9a8996344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1313,8 +1313,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the cosine inverse of the given value; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1322,8 +1321,7 @@ object functions { def acos(e: Column): Column = withExpr { Acos(e.expr) } /** - * Computes the cosine inverse of the given column; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1331,8 +1329,7 @@ object functions { def acos(columnName: String): Column = acos(Column(columnName)) /** - * Computes the sine inverse of the given value; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1340,8 +1337,7 @@ object functions { def asin(e: Column): Column = withExpr { Asin(e.expr) } /** - * Computes the sine inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1349,8 +1345,7 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `e`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1358,8 +1353,7 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1367,77 +1361,117 @@ object functions { def atan(columnName: String): Column = atan(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). Units in radians. + * @param y coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } + def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) } /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, rightName: String): Column = - atan2(Column(leftName), Column(rightName)) + def atan2(yName: String, xName: String): Column = + atan2(Column(yName), Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) + def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l), r) + def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) /** * An expression that returns the string representation of the binary value of the given long @@ -1500,7 +1534,8 @@ object functions { } /** - * Computes the cosine of the given value. Units in radians. + * @param e angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1508,7 +1543,8 @@ object functions { def cos(e: Column): Column = withExpr { Cos(e.expr) } /** - * Computes the cosine of the given column. + * @param columnName angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1516,7 +1552,8 @@ object functions { def cos(columnName: String): Column = cos(Column(columnName)) /** - * Computes the hyperbolic cosine of the given value. + * @param e hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1524,7 +1561,8 @@ object functions { def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** - * Computes the hyperbolic cosine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1967,7 +2005,8 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. Units in radians. + * @param e angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1975,7 +2014,8 @@ object functions { def sin(e: Column): Column = withExpr { Sin(e.expr) } /** - * Computes the sine of the given column. + * @param columnName angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1983,7 +2023,8 @@ object functions { def sin(columnName: String): Column = sin(Column(columnName)) /** - * Computes the hyperbolic sine of the given value. + * @param e hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1991,7 +2032,8 @@ object functions { def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** - * Computes the hyperbolic sine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1999,7 +2041,8 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. Units in radians. + * @param e angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2007,7 +2050,8 @@ object functions { def tan(e: Column): Column = withExpr { Tan(e.expr) } /** - * Computes the tangent of the given column. + * @param columnName angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2015,7 +2059,8 @@ object functions { def tan(columnName: String): Column = tan(Column(columnName)) /** - * Computes the hyperbolic tangent of the given value. + * @param e hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2023,7 +2068,8 @@ object functions { def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** - * Computes the hyperbolic tangent of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2047,6 +2093,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param e angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2055,6 +2104,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param columnName angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2077,6 +2129,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param e angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2085,6 +2140,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param columnName angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2873,7 +2931,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2929,7 +2987,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 From 947b4e6f09db6aa5d92409344b6e273e9faeb24e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 5 Mar 2018 16:21:02 +0100 Subject: [PATCH 118/593] [SPARK-23510][DOC][FOLLOW-UP] Update spark.sql.hive.metastore.version ## What changes were proposed in this pull request? Update `spark.sql.hive.metastore.version` to 2.3.2, same as HiveUtils.scala: https://github.com/apache/spark/blob/ff1480189b827af0be38605d566a4ee71b4c36f6/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala#L63-L65 ## How was this patch tested? N/A Author: Yuming Wang Closes #20734 from wangyum/SPARK-23510-FOLLOW-UP. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d0f015f401bb..01e2076555ee6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 1.2.1. + options are 0.12.0 through 2.3.2. From 4586eada42d6a16bb78d1650d145531c51fa747f Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 5 Mar 2018 09:30:49 -0800 Subject: [PATCH 119/593] [SPARK-22430][R][DOCS] Unknown tag warnings when building R docs with Roxygen 6.0.1 ## What changes were proposed in this pull request? Removed export tag to get rid of unknown tag warnings ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20501 from rekhajoshm/SPARK-22430. --- R/pkg/R/DataFrame.R | 92 --------- R/pkg/R/SQLContext.R | 16 -- R/pkg/R/WindowSpec.R | 8 - R/pkg/R/broadcast.R | 3 - R/pkg/R/catalog.R | 18 -- R/pkg/R/column.R | 7 - R/pkg/R/context.R | 6 - R/pkg/R/functions.R | 181 ----------------- R/pkg/R/generics.R | 343 --------------------------------- R/pkg/R/group.R | 7 - R/pkg/R/install.R | 1 - R/pkg/R/jvm.R | 3 - R/pkg/R/mllib_classification.R | 20 -- R/pkg/R/mllib_clustering.R | 23 --- R/pkg/R/mllib_fpm.R | 6 - R/pkg/R/mllib_recommendation.R | 5 - R/pkg/R/mllib_regression.R | 17 -- R/pkg/R/mllib_stat.R | 4 - R/pkg/R/mllib_tree.R | 33 ---- R/pkg/R/mllib_utils.R | 3 - R/pkg/R/schema.R | 7 - R/pkg/R/sparkR.R | 7 - R/pkg/R/stats.R | 6 - R/pkg/R/streaming.R | 9 - R/pkg/R/utils.R | 1 - R/pkg/R/window.R | 4 - 26 files changed, 830 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 41c3c3a89fa72..c4852024c0f49 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -36,7 +36,6 @@ setOldClass("structType") #' @slot sdf A Java object reference to the backing Scala DataFrame #' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -77,7 +76,6 @@ setWriteMode <- function(write, mode) { write } -#' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached #' @noRd @@ -97,7 +95,6 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @rdname printSchema #' @name printSchema #' @aliases printSchema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -123,7 +120,6 @@ setMethod("printSchema", #' @rdname schema #' @name schema #' @aliases schema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -146,7 +142,6 @@ setMethod("schema", #' @aliases explain,SparkDataFrame-method #' @rdname explain #' @name explain -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -178,7 +173,6 @@ setMethod("explain", #' @rdname isLocal #' @name isLocal #' @aliases isLocal,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -209,7 +203,6 @@ setMethod("isLocal", #' @aliases showDF,SparkDataFrame-method #' @rdname showDF #' @name showDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -241,7 +234,6 @@ setMethod("showDF", #' @rdname show #' @aliases show,SparkDataFrame-method #' @name show -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -269,7 +261,6 @@ setMethod("show", "SparkDataFrame", #' @rdname dtypes #' @name dtypes #' @aliases dtypes,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -296,7 +287,6 @@ setMethod("dtypes", #' @rdname columns #' @name columns #' @aliases columns,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -388,7 +378,6 @@ setMethod("colnames<-", #' @aliases coltypes,SparkDataFrame-method #' @name coltypes #' @family SparkDataFrame functions -#' @export #' @examples #'\dontrun{ #' irisDF <- createDataFrame(iris) @@ -445,7 +434,6 @@ setMethod("coltypes", #' @rdname coltypes #' @name coltypes<- #' @aliases coltypes<-,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -494,7 +482,6 @@ setMethod("coltypes<-", #' @rdname createOrReplaceTempView #' @name createOrReplaceTempView #' @aliases createOrReplaceTempView,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -521,7 +508,6 @@ setMethod("createOrReplaceTempView", #' @rdname registerTempTable-deprecated #' @name registerTempTable #' @aliases registerTempTable,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -552,7 +538,6 @@ setMethod("registerTempTable", #' @rdname insertInto #' @name insertInto #' @aliases insertInto,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -580,7 +565,6 @@ setMethod("insertInto", #' @aliases cache,SparkDataFrame-method #' @rdname cache #' @name cache -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -611,7 +595,6 @@ setMethod("cache", #' @rdname persist #' @name persist #' @aliases persist,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -641,7 +624,6 @@ setMethod("persist", #' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -669,7 +651,6 @@ setMethod("unpersist", #' @rdname storageLevel #' @aliases storageLevel,SparkDataFrame-method #' @name storageLevel -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -707,7 +688,6 @@ setMethod("storageLevel", #' @name coalesce #' @aliases coalesce,SparkDataFrame-method #' @seealso \link{repartition} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -744,7 +724,6 @@ setMethod("coalesce", #' @name repartition #' @aliases repartition,SparkDataFrame-method #' @seealso \link{coalesce} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -793,7 +772,6 @@ setMethod("repartition", #' @rdname toJSON #' @name toJSON #' @aliases toJSON,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -826,7 +804,6 @@ setMethod("toJSON", #' @rdname write.json #' @name write.json #' @aliases write.json,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -858,7 +835,6 @@ setMethod("write.json", #' @aliases write.orc,SparkDataFrame,character-method #' @rdname write.orc #' @name write.orc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -890,7 +866,6 @@ setMethod("write.orc", #' @rdname write.parquet #' @name write.parquet #' @aliases write.parquet,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -911,7 +886,6 @@ setMethod("write.parquet", #' @rdname write.parquet #' @name saveAsParquetFile #' @aliases saveAsParquetFile,SparkDataFrame,character-method -#' @export #' @note saveAsParquetFile since 1.4.0 setMethod("saveAsParquetFile", signature(x = "SparkDataFrame", path = "character"), @@ -936,7 +910,6 @@ setMethod("saveAsParquetFile", #' @aliases write.text,SparkDataFrame,character-method #' @rdname write.text #' @name write.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -963,7 +936,6 @@ setMethod("write.text", #' @aliases distinct,SparkDataFrame-method #' @rdname distinct #' @name distinct -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1004,7 +976,6 @@ setMethod("unique", #' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1061,7 +1032,6 @@ setMethod("sample_frac", #' @rdname nrow #' @name nrow #' @aliases count,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1094,7 +1064,6 @@ setMethod("nrow", #' @rdname ncol #' @name ncol #' @aliases ncol,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1118,7 +1087,6 @@ setMethod("ncol", #' @rdname dim #' @aliases dim,SparkDataFrame-method #' @name dim -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1144,7 +1112,6 @@ setMethod("dim", #' @rdname collect #' @aliases collect,SparkDataFrame-method #' @name collect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1229,7 +1196,6 @@ setMethod("collect", #' @rdname limit #' @name limit #' @aliases limit,SparkDataFrame,numeric-method -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -1253,7 +1219,6 @@ setMethod("limit", #' @rdname take #' @name take #' @aliases take,SparkDataFrame,numeric-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1282,7 +1247,6 @@ setMethod("take", #' @aliases head,SparkDataFrame-method #' @rdname head #' @name head -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1307,7 +1271,6 @@ setMethod("head", #' @aliases first,SparkDataFrame-method #' @rdname first #' @name first -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1359,7 +1322,6 @@ setMethod("toRDD", #' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy #' @name groupBy -#' @export #' @examples #' \dontrun{ #' # Compute the average for all numeric columns grouped by department. @@ -1401,7 +1363,6 @@ setMethod("group_by", #' @aliases agg,SparkDataFrame-method #' @rdname summarize #' @name agg -#' @export #' @note agg since 1.4.0 setMethod("agg", signature(x = "SparkDataFrame"), @@ -1460,7 +1421,6 @@ setClassUnion("characterOrstructType", c("character", "structType")) #' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1519,7 +1479,6 @@ setMethod("dapply", #' @aliases dapplyCollect,SparkDataFrame,function-method #' @name dapplyCollect #' @seealso \link{dapply} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1576,7 +1535,6 @@ setMethod("dapplyCollect", #' @rdname gapply #' @name gapply #' @seealso \link{gapplyCollect} -#' @export #' @examples #' #' \dontrun{ @@ -1673,7 +1631,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @name gapplyCollect #' @seealso \link{gapply} -#' @export #' @examples #' #' \dontrun{ @@ -1947,7 +1904,6 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected #' columns. -#' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method #' @seealso \link{withColumn} @@ -1992,7 +1948,6 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' If more than one column is assigned in \code{col}, \code{...} #' should be left empty. #' @return A new SparkDataFrame with selected columns. -#' @export #' @family SparkDataFrame functions #' @rdname select #' @aliases select,SparkDataFrame,character-method @@ -2024,7 +1979,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "character"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,Column-method #' @note select(SparkDataFrame, Column) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "Column"), @@ -2037,7 +1991,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "Column"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,list-method #' @note select(SparkDataFrame, list) since 1.4.0 setMethod("select", @@ -2066,7 +2019,6 @@ setMethod("select", #' @aliases selectExpr,SparkDataFrame,character-method #' @rdname selectExpr #' @name selectExpr -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2098,7 +2050,6 @@ setMethod("selectExpr", #' @rdname withColumn #' @name withColumn #' @seealso \link{rename} \link{mutate} \link{subset} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2137,7 +2088,6 @@ setMethod("withColumn", #' @rdname mutate #' @name mutate #' @seealso \link{rename} \link{withColumn} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2208,7 +2158,6 @@ setMethod("mutate", }) #' @param _data a SparkDataFrame. -#' @export #' @rdname mutate #' @aliases transform,SparkDataFrame-method #' @name transform @@ -2232,7 +2181,6 @@ setMethod("transform", #' @name withColumnRenamed #' @aliases withColumnRenamed,SparkDataFrame,character,character-method #' @seealso \link{mutate} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2258,7 +2206,6 @@ setMethod("withColumnRenamed", #' @rdname rename #' @name rename #' @aliases rename,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2304,7 +2251,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @aliases arrange,SparkDataFrame,Column-method #' @rdname arrange #' @name arrange -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2335,7 +2281,6 @@ setMethod("arrange", #' @rdname arrange #' @name arrange #' @aliases arrange,SparkDataFrame,character-method -#' @export #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), @@ -2368,7 +2313,6 @@ setMethod("arrange", #' @rdname arrange #' @aliases orderBy,SparkDataFrame,characterOrColumn-method -#' @export #' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 setMethod("orderBy", signature(x = "SparkDataFrame", col = "characterOrColumn"), @@ -2389,7 +2333,6 @@ setMethod("orderBy", #' @rdname filter #' @name filter #' @family subsetting functions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2432,7 +2375,6 @@ setMethod("where", #' @aliases dropDuplicates,SparkDataFrame-method #' @rdname dropDuplicates #' @name dropDuplicates -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2481,7 +2423,6 @@ setMethod("dropDuplicates", #' @rdname join #' @name join #' @seealso \link{merge} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2533,7 +2474,6 @@ setMethod("join", #' @rdname crossJoin #' @name crossJoin #' @seealso \link{merge} \link{join} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2581,7 +2521,6 @@ setMethod("crossJoin", #' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge #' @seealso \link{join} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2721,7 +2660,6 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { #' @name union #' @aliases union,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2742,7 +2680,6 @@ setMethod("union", #' @rdname union #' @name unionAll #' @aliases unionAll,SparkDataFrame,SparkDataFrame-method -#' @export #' @note unionAll since 1.4.0 setMethod("unionAll", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2769,7 +2706,6 @@ setMethod("unionAll", #' @name unionByName #' @aliases unionByName,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{union} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2802,7 +2738,6 @@ setMethod("unionByName", #' @rdname rbind #' @name rbind #' @seealso \link{union} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2835,7 +2770,6 @@ setMethod("rbind", #' @aliases intersect,SparkDataFrame,SparkDataFrame-method #' @rdname intersect #' @name intersect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2863,7 +2797,6 @@ setMethod("intersect", #' @aliases except,SparkDataFrame,SparkDataFrame-method #' @rdname except #' @name except -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2872,7 +2805,6 @@ setMethod("intersect", #' exceptDF <- except(df, df2) #' } #' @rdname except -#' @export #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2909,7 +2841,6 @@ setMethod("except", #' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2944,7 +2875,6 @@ setMethod("write.df", #' @rdname write.df #' @name saveDF #' @aliases saveDF,SparkDataFrame,character-method -#' @export #' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), @@ -2978,7 +2908,6 @@ setMethod("saveDF", #' @aliases saveAsTable,SparkDataFrame,character-method #' @rdname saveAsTable #' @name saveAsTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3015,7 +2944,6 @@ setMethod("saveAsTable", #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname describe #' @name describe -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3071,7 +2999,6 @@ setMethod("describe", #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3117,7 +3044,6 @@ setMethod("summary", #' @rdname nafunctions #' @aliases dropna,SparkDataFrame-method #' @name dropna -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3148,7 +3074,6 @@ setMethod("dropna", #' @rdname nafunctions #' @name na.omit #' @aliases na.omit,SparkDataFrame-method -#' @export #' @note na.omit since 1.5.0 setMethod("na.omit", signature(object = "SparkDataFrame"), @@ -3168,7 +3093,6 @@ setMethod("na.omit", #' @rdname nafunctions #' @name fillna #' @aliases fillna,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3399,7 +3323,6 @@ setMethod("str", #' @rdname drop #' @name drop #' @aliases drop,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3427,7 +3350,6 @@ setMethod("drop", #' @name drop #' @rdname drop #' @aliases drop,ANY-method -#' @export setMethod("drop", signature(x = "ANY"), function(x) { @@ -3446,7 +3368,6 @@ setMethod("drop", #' @rdname histogram #' @aliases histogram,SparkDataFrame,characterOrColumn-method #' @family SparkDataFrame functions -#' @export #' @examples #' \dontrun{ #' @@ -3582,7 +3503,6 @@ setMethod("histogram", #' @rdname write.jdbc #' @name write.jdbc #' @aliases write.jdbc,SparkDataFrame,character,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3611,7 +3531,6 @@ setMethod("write.jdbc", #' @aliases randomSplit,SparkDataFrame,numeric-method #' @rdname randomSplit #' @name randomSplit -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3645,7 +3564,6 @@ setMethod("randomSplit", #' @aliases getNumPartitions,SparkDataFrame-method #' @rdname getNumPartitions #' @name getNumPartitions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3672,7 +3590,6 @@ setMethod("getNumPartitions", #' @rdname isStreaming #' @name isStreaming #' @seealso \link{read.stream} \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3726,7 +3643,6 @@ setMethod("isStreaming", #' @aliases write.stream,SparkDataFrame-method #' @rdname write.stream #' @name write.stream -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3819,7 +3735,6 @@ setMethod("write.stream", #' @rdname checkpoint #' @name checkpoint #' @seealso \link{setCheckpointDir} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") @@ -3847,7 +3762,6 @@ setMethod("checkpoint", #' @aliases localCheckpoint,SparkDataFrame-method #' @rdname localCheckpoint #' @name localCheckpoint -#' @export #' @examples #'\dontrun{ #' df <- localCheckpoint(df) @@ -3874,7 +3788,6 @@ setMethod("localCheckpoint", #' @aliases cube,SparkDataFrame-method #' @rdname cube #' @name cube -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3909,7 +3822,6 @@ setMethod("cube", #' @aliases rollup,SparkDataFrame-method #' @rdname rollup #' @name rollup -#' @export #' @examples #'\dontrun{ #' df <- createDataFrame(mtcars) @@ -3942,7 +3854,6 @@ setMethod("rollup", #' @aliases hint,SparkDataFrame,character-method #' @rdname hint #' @name hint -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3966,7 +3877,6 @@ setMethod("hint", #' @family SparkDataFrame functions #' @rdname alias #' @name alias -#' @export #' @examples #' \dontrun{ #' df <- alias(createDataFrame(mtcars), "mtcars") @@ -3997,7 +3907,6 @@ setMethod("alias", #' @family SparkDataFrame functions #' @rdname broadcast #' @name broadcast -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -4041,7 +3950,6 @@ setMethod("broadcast", #' @family SparkDataFrame functions #' @rdname withWatermark #' @name withWatermark -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9d0a2d5e074e4..ebec0ce3d1920 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -123,7 +123,6 @@ infer_type <- function(x) { #' @return a list of config values with keys as their names #' @rdname sparkR.conf #' @name sparkR.conf -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -163,7 +162,6 @@ sparkR.conf <- function(key, defaultValue) { #' @return a character string of the Spark version #' @rdname sparkR.version #' @name sparkR.version -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -191,7 +189,6 @@ getDefaultSqlSource <- function() { #' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. #' @rdname createDataFrame -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -294,7 +291,6 @@ createDataFrame <- function(x, ...) { #' @rdname createDataFrame #' @aliases createDataFrame -#' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { @@ -304,7 +300,6 @@ as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPa #' @param ... additional argument(s). #' @rdname createDataFrame #' @aliases as.DataFrame -#' @export as.DataFrame <- function(data, ...) { dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } @@ -342,7 +337,6 @@ setMethod("toDF", signature(x = "RDD"), #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -371,7 +365,6 @@ read.json <- function(x, ...) { #' @rdname read.json #' @name jsonFile -#' @export #' @method jsonFile default #' @note jsonFile since 1.4.0 jsonFile.default <- function(path) { @@ -423,7 +416,6 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc -#' @export #' @name read.orc #' @note read.orc since 2.0.0 read.orc <- function(path, ...) { @@ -444,7 +436,6 @@ read.orc <- function(path, ...) { #' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @export #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 @@ -466,7 +457,6 @@ read.parquet <- function(x, ...) { #' @param ... argument(s) passed to the method. #' @rdname read.parquet #' @name parquetFile -#' @export #' @method parquetFile default #' @note parquetFile since 1.4.0 parquetFile.default <- function(...) { @@ -490,7 +480,6 @@ parquetFile <- function(x, ...) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -522,7 +511,6 @@ read.text <- function(x, ...) { #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame #' @rdname sql -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -556,7 +544,6 @@ sql <- function(x, ...) { #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -591,7 +578,6 @@ tableToDF <- function(tableName) { #' @rdname read.df #' @name read.df #' @seealso \link{read.json} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -681,7 +667,6 @@ loadDF <- function(x = NULL, ...) { #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -734,7 +719,6 @@ read.jdbc <- function(url, tableName, #' @rdname read.stream #' @name read.stream #' @seealso \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index debc7cbde55e7..ee7f4adf726e6 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{windowPartitionBy}, \link{windowOrderBy} #' #' @param sws A Java object reference to the backing Scala WindowSpec -#' @export #' @note WindowSpec since 2.0.0 setClass("WindowSpec", slots = list(sws = "jobj")) @@ -44,7 +43,6 @@ windowSpec <- function(sws) { } #' @rdname show -#' @export #' @note show(WindowSpec) since 2.0.0 setMethod("show", "WindowSpec", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "WindowSpec", #' @name partitionBy #' @aliases partitionBy,WindowSpec-method #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' partitionBy(ws, "col1", "col2") @@ -97,7 +94,6 @@ setMethod("partitionBy", #' @aliases orderBy,WindowSpec,character-method #' @family windowspec_method #' @seealso See \link{arrange} for use in sorting a SparkDataFrame -#' @export #' @examples #' \dontrun{ #' orderBy(ws, "col1", "col2") @@ -113,7 +109,6 @@ setMethod("orderBy", #' @rdname orderBy #' @name orderBy #' @aliases orderBy,WindowSpec,Column-method -#' @export #' @note orderBy(WindowSpec, Column) since 2.0.0 setMethod("orderBy", signature(x = "WindowSpec", col = "Column"), @@ -142,7 +137,6 @@ setMethod("orderBy", #' @aliases rowsBetween,WindowSpec,numeric,numeric-method #' @name rowsBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rowsBetween(ws, 0, 3) @@ -174,7 +168,6 @@ setMethod("rowsBetween", #' @aliases rangeBetween,WindowSpec,numeric,numeric-method #' @name rangeBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rangeBetween(ws, 0, 3) @@ -202,7 +195,6 @@ setMethod("rangeBetween", #' @name over #' @aliases over,Column,WindowSpec-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 398dffc4ab1b4..282f8a6857738 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -32,14 +32,12 @@ # @seealso broadcast # # @param id Id of the backing Spark broadcast variable -# @export setClass("Broadcast", slots = list(id = "character")) # @rdname broadcast-class # @param value Value of the broadcast variable # @param jBroadcastRef reference to the backing Java broadcast object # @param objName name of broadcasted object -# @export Broadcast <- function(id, value, jBroadcastRef, objName) { .broadcastValues[[id]] <- value .broadcastNames[[as.character(objName)]] <- jBroadcastRef @@ -73,7 +71,6 @@ setMethod("value", # @param bcastId The id of broadcast variable to set # @param value The value to be set -# @export setBroadcastValue <- function(bcastId, value) { bcastIdStr <- as.character(bcastId) .broadcastValues[[bcastIdStr]] <- value diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index e59a7024333ac..baf4d861fcf86 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -34,7 +34,6 @@ #' @return A SparkDataFrame. #' @rdname createExternalTable-deprecated #' @seealso \link{createTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -71,7 +70,6 @@ createExternalTable <- function(x, ...) { #' @return A SparkDataFrame. #' @rdname createTable #' @seealso \link{createExternalTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -110,7 +108,6 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, .. #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -140,7 +137,6 @@ cacheTable <- function(x, ...) { #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -167,7 +163,6 @@ uncacheTable <- function(x, ...) { #' Removes all cached tables from the in-memory cache. #' #' @rdname clearCache -#' @export #' @examples #' \dontrun{ #' clearCache() @@ -193,7 +188,6 @@ clearCache <- function() { #' @param tableName The name of the SparkSQL table to be dropped. #' @seealso \link{dropTempView} #' @rdname dropTempTable-deprecated -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -225,7 +219,6 @@ dropTempTable <- function(x, ...) { #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +244,6 @@ dropTempView <- function(viewName) { #' @return a SparkDataFrame #' @rdname tables #' @seealso \link{listTables} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -276,7 +268,6 @@ tables <- function(x, ...) { #' @param databaseName (optional) name of the database #' @return a list of table names #' @rdname tableNames -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -304,7 +295,6 @@ tableNames <- function(x, ...) { #' @return name of the current default database. #' @rdname currentDatabase #' @name currentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -324,7 +314,6 @@ currentDatabase <- function() { #' @param databaseName name of the database #' @rdname setCurrentDatabase #' @name setCurrentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -347,7 +336,6 @@ setCurrentDatabase <- function(databaseName) { #' @return a SparkDataFrame of the list of databases. #' @rdname listDatabases #' @name listDatabases -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -370,7 +358,6 @@ listDatabases <- function() { #' @rdname listTables #' @name listTables #' @seealso \link{tables} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -403,7 +390,6 @@ listTables <- function(databaseName = NULL) { #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns #' @name listColumns -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -433,7 +419,6 @@ listColumns <- function(tableName, databaseName = NULL) { #' @return a SparkDataFrame of the list of function descriptions. #' @rdname listFunctions #' @name listFunctions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -463,7 +448,6 @@ listFunctions <- function(databaseName = NULL) { #' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -490,7 +474,6 @@ recoverPartitions <- function(tableName) { #' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -512,7 +495,6 @@ refreshTable <- function(tableName) { #' @param path the path of the data source. #' @rdname refreshByPath #' @name refreshByPath -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 3095adb918b67..9727efc354f10 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -29,7 +29,6 @@ setOldClass("jobj") #' @rdname column #' #' @slot jc reference to JVM SparkDataFrame column -#' @export #' @note Column since 1.4.0 setClass("Column", slots = list(jc = "jobj")) @@ -56,7 +55,6 @@ setMethod("column", #' @rdname show #' @name show #' @aliases show,Column-method -#' @export #' @note show(Column) since 1.4.0 setMethod("show", "Column", function(object) { @@ -134,7 +132,6 @@ createMethods() #' @name alias #' @aliases alias,Column-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -270,7 +267,6 @@ setMethod("cast", #' @name %in% #' @aliases %in%,Column-method #' @return A matched values as a result of comparing with given values. -#' @export #' @examples #' \dontrun{ #' filter(df, "age in (10, 30)") @@ -296,7 +292,6 @@ setMethod("%in%", #' @name otherwise #' @family colum_func #' @aliases otherwise,Column-method -#' @export #' @note otherwise since 1.5.0 setMethod("otherwise", signature(x = "Column", value = "ANY"), @@ -318,7 +313,6 @@ setMethod("otherwise", #' @rdname eq_null_safe #' @name %<=>% #' @aliases %<=>%,Column-method -#' @export #' @examples #' \dontrun{ #' df1 <- createDataFrame(data.frame( @@ -348,7 +342,6 @@ setMethod("%<=>%", #' @rdname not #' @name not #' @aliases !,Column-method -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 443c2ff8f9ace..8ec727dd042bc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -308,7 +308,6 @@ setCheckpointDirSC <- function(sc, dirName) { #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. -#' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") @@ -323,7 +322,6 @@ spark.addFile <- function(path, recursive = FALSE) { #' #' @rdname spark.getSparkFilesRootDirectory #' @return the root directory that contains files added through spark.addFile -#' @export #' @examples #'\dontrun{ #' spark.getSparkFilesRootDirectory() @@ -344,7 +342,6 @@ spark.getSparkFilesRootDirectory <- function() { # nolint #' @rdname spark.getSparkFiles #' @param fileName The name of the file added through spark.addFile #' @return the absolute path of a file added through spark.addFile. -#' @export #' @examples #'\dontrun{ #' spark.getSparkFiles("myfile") @@ -391,7 +388,6 @@ spark.getSparkFiles <- function(fileName) { #' @param list the list of elements #' @param func a function that takes one argument. #' @return a list of results (the exact type being determined by the function) -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -412,7 +408,6 @@ spark.lapply <- function(list, func) { #' #' @rdname setLogLevel #' @param level New log level -#' @export #' @examples #'\dontrun{ #' setLogLevel("ERROR") @@ -431,7 +426,6 @@ setLogLevel <- function(level) { #' @rdname setCheckpointDir #' @param directory Directory path to checkpoint to #' @seealso \link{checkpoint} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 29ee146ab14f9..a527426b19674 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -244,7 +244,6 @@ NULL #' If the parameter is a Column, it is returned unchanged. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases lit lit,ANY-method #' @examples #' @@ -267,7 +266,6 @@ setMethod("lit", signature("ANY"), #' \code{abs}: Computes the absolute value. #' #' @rdname column_math_functions -#' @export #' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", @@ -282,7 +280,6 @@ setMethod("abs", #' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions -#' @export #' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", @@ -296,7 +293,6 @@ setMethod("acos", #' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' #' @rdname column_aggregate_functions -#' @export #' @aliases approxCountDistinct approxCountDistinct,Column-method #' @examples #' @@ -319,7 +315,6 @@ setMethod("approxCountDistinct", #' and returns the result as an int column. #' #' @rdname column_string_functions -#' @export #' @aliases ascii ascii,Column-method #' @examples #' @@ -338,7 +333,6 @@ setMethod("ascii", #' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions -#' @export #' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", @@ -353,7 +347,6 @@ setMethod("asin", #' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions -#' @export #' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", @@ -370,7 +363,6 @@ setMethod("atan", #' @rdname avg #' @name avg #' @family aggregate functions -#' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} #' @note avg since 1.4.0 @@ -386,7 +378,6 @@ setMethod("avg", #' a string column. This is the reverse of unbase64. #' #' @rdname column_string_functions -#' @export #' @aliases base64 base64,Column-method #' @examples #' @@ -410,7 +401,6 @@ setMethod("base64", #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions -#' @export #' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", @@ -424,7 +414,6 @@ setMethod("bin", #' \code{bitwiseNOT}: Computes bitwise NOT. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases bitwiseNOT bitwiseNOT,Column-method #' @examples #' @@ -442,7 +431,6 @@ setMethod("bitwiseNOT", #' \code{cbrt}: Computes the cube-root of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", @@ -456,7 +444,6 @@ setMethod("cbrt", #' \code{ceil}: Computes the ceiling of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", @@ -471,7 +458,6 @@ setMethod("ceil", #' #' @rdname column_math_functions #' @aliases ceiling ceiling,Column-method -#' @export #' @note ceiling since 1.5.0 setMethod("ceiling", signature(x = "Column"), @@ -483,7 +469,6 @@ setMethod("ceiling", #' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases coalesce,Column-method #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", @@ -514,7 +499,6 @@ col <- function(x) { #' @rdname column #' @name column #' @family non-aggregate functions -#' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} #' @note column since 1.6.0 @@ -533,7 +517,6 @@ setMethod("column", #' @rdname corr #' @name corr #' @family aggregate functions -#' @export #' @aliases corr,Column-method #' @examples #' \dontrun{ @@ -557,7 +540,6 @@ setMethod("corr", signature(x = "Column"), #' @rdname cov #' @name cov #' @family aggregate functions -#' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ @@ -598,7 +580,6 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname cov #' @name covar_pop -#' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), @@ -618,7 +599,6 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' @rdname column_math_functions #' @aliases cos cos,Column-method -#' @export #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -633,7 +613,6 @@ setMethod("cos", #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method -#' @export #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -651,7 +630,6 @@ setMethod("cosh", #' @name count #' @family aggregate functions #' @aliases count,Column-method -#' @export #' @examples \dontrun{count(df$c)} #' @note count since 1.4.0 setMethod("count", @@ -667,7 +645,6 @@ setMethod("count", #' #' @rdname column_misc_functions #' @aliases crc32 crc32,Column-method -#' @export #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -682,7 +659,6 @@ setMethod("crc32", #' #' @rdname column_misc_functions #' @aliases hash hash,Column-method -#' @export #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -701,7 +677,6 @@ setMethod("hash", #' #' @rdname column_datetime_functions #' @aliases dayofmonth dayofmonth,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -723,7 +698,6 @@ setMethod("dayofmonth", #' #' @rdname column_datetime_functions #' @aliases dayofweek dayofweek,Column-method -#' @export #' @note dayofweek since 2.3.0 setMethod("dayofweek", signature(x = "Column"), @@ -738,7 +712,6 @@ setMethod("dayofweek", #' #' @rdname column_datetime_functions #' @aliases dayofyear dayofyear,Column-method -#' @export #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -756,7 +729,6 @@ setMethod("dayofyear", #' #' @rdname column_string_functions #' @aliases decode decode,Column,character-method -#' @export #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -771,7 +743,6 @@ setMethod("decode", #' #' @rdname column_string_functions #' @aliases encode encode,Column,character-method -#' @export #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -785,7 +756,6 @@ setMethod("encode", #' #' @rdname column_math_functions #' @aliases exp exp,Column-method -#' @export #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -799,7 +769,6 @@ setMethod("exp", #' #' @rdname column_math_functions #' @aliases expm1 expm1,Column-method -#' @export #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -813,7 +782,6 @@ setMethod("expm1", #' #' @rdname column_math_functions #' @aliases factorial factorial,Column-method -#' @export #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -836,7 +804,6 @@ setMethod("factorial", #' @name first #' @aliases first,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' first(df$c) @@ -860,7 +827,6 @@ setMethod("first", #' #' @rdname column_math_functions #' @aliases floor floor,Column-method -#' @export #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -874,7 +840,6 @@ setMethod("floor", #' #' @rdname column_math_functions #' @aliases hex hex,Column-method -#' @export #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -888,7 +853,6 @@ setMethod("hex", #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -911,7 +875,6 @@ setMethod("hour", #' #' @rdname column_string_functions #' @aliases initcap initcap,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -946,7 +909,6 @@ setMethod("isnan", #' #' @rdname column_nonaggregate_functions #' @aliases is.nan is.nan,Column-method -#' @export #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -959,7 +921,6 @@ setMethod("is.nan", #' #' @rdname column_aggregate_functions #' @aliases kurtosis kurtosis,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -988,7 +949,6 @@ setMethod("kurtosis", #' @name last #' @aliases last,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' last(df$c) @@ -1014,7 +974,6 @@ setMethod("last", #' #' @rdname column_datetime_functions #' @aliases last_day last_day,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1034,7 +993,6 @@ setMethod("last_day", #' #' @rdname column_string_functions #' @aliases length length,Column-method -#' @export #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -1048,7 +1006,6 @@ setMethod("length", #' #' @rdname column_math_functions #' @aliases log log,Column-method -#' @export #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1062,7 +1019,6 @@ setMethod("log", #' #' @rdname column_math_functions #' @aliases log10 log10,Column-method -#' @export #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1076,7 +1032,6 @@ setMethod("log10", #' #' @rdname column_math_functions #' @aliases log1p log1p,Column-method -#' @export #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1090,7 +1045,6 @@ setMethod("log1p", #' #' @rdname column_math_functions #' @aliases log2 log2,Column-method -#' @export #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1104,7 +1058,6 @@ setMethod("log2", #' #' @rdname column_string_functions #' @aliases lower lower,Column-method -#' @export #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1119,7 +1072,6 @@ setMethod("lower", #' #' @rdname column_string_functions #' @aliases ltrim ltrim,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1143,7 +1095,6 @@ setMethod("ltrim", #' @param trimString a character string to trim with #' @rdname column_string_functions #' @aliases ltrim,Column,character-method -#' @export #' @note ltrim(Column, character) since 2.3.0 setMethod("ltrim", signature(x = "Column", trimString = "character"), @@ -1171,7 +1122,6 @@ setMethod("max", #' #' @rdname column_misc_functions #' @aliases md5 md5,Column-method -#' @export #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1185,7 +1135,6 @@ setMethod("md5", #' #' @rdname column_aggregate_functions #' @aliases mean mean,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1211,7 +1160,6 @@ setMethod("mean", #' #' @rdname column_aggregate_functions #' @aliases min min,Column-method -#' @export #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1225,7 +1173,6 @@ setMethod("min", #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method -#' @export #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1248,7 +1195,6 @@ setMethod("minute", #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, monotonically_increasing_id()))} @@ -1264,7 +1210,6 @@ setMethod("monotonically_increasing_id", #' #' @rdname column_datetime_functions #' @aliases month month,Column-method -#' @export #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1278,7 +1223,6 @@ setMethod("month", #' #' @rdname column_nonaggregate_functions #' @aliases negate negate,Column-method -#' @export #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1292,7 +1236,6 @@ setMethod("negate", #' #' @rdname column_datetime_functions #' @aliases quarter quarter,Column-method -#' @export #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1306,7 +1249,6 @@ setMethod("quarter", #' #' @rdname column_string_functions #' @aliases reverse reverse,Column-method -#' @export #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1321,7 +1263,6 @@ setMethod("reverse", #' #' @rdname column_math_functions #' @aliases rint rint,Column-method -#' @export #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1336,7 +1277,6 @@ setMethod("rint", #' #' @rdname column_math_functions #' @aliases round round,Column-method -#' @export #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1356,7 +1296,6 @@ setMethod("round", #' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method -#' @export #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1371,7 +1310,6 @@ setMethod("bround", #' #' @rdname column_string_functions #' @aliases rtrim rtrim,Column,missing-method -#' @export #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column", trimString = "missing"), @@ -1382,7 +1320,6 @@ setMethod("rtrim", #' @rdname column_string_functions #' @aliases rtrim,Column,character-method -#' @export #' @note rtrim(Column, character) since 2.3.0 setMethod("rtrim", signature(x = "Column", trimString = "character"), @@ -1396,7 +1333,6 @@ setMethod("rtrim", #' #' @rdname column_aggregate_functions #' @aliases sd sd,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1414,7 +1350,6 @@ setMethod("sd", #' #' @rdname column_datetime_functions #' @aliases second second,Column-method -#' @export #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1429,7 +1364,6 @@ setMethod("second", #' #' @rdname column_misc_functions #' @aliases sha1 sha1,Column-method -#' @export #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -1443,7 +1377,6 @@ setMethod("sha1", #' #' @rdname column_math_functions #' @aliases signum signum,Column-method -#' @export #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1457,7 +1390,6 @@ setMethod("signum", #' #' @rdname column_math_functions #' @aliases sign sign,Column-method -#' @export #' @note sign since 1.5.0 setMethod("sign", signature(x = "Column"), function(x) { @@ -1470,7 +1402,6 @@ setMethod("sign", signature(x = "Column"), #' #' @rdname column_math_functions #' @aliases sin sin,Column-method -#' @export #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1485,7 +1416,6 @@ setMethod("sin", #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method -#' @export #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1499,7 +1429,6 @@ setMethod("sinh", #' #' @rdname column_aggregate_functions #' @aliases skewness skewness,Column-method -#' @export #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1513,7 +1442,6 @@ setMethod("skewness", #' #' @rdname column_string_functions #' @aliases soundex soundex,Column-method -#' @export #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1530,7 +1458,6 @@ setMethod("soundex", #' #' @rdname column_nonaggregate_functions #' @aliases spark_partition_id spark_partition_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, spark_partition_id()))} @@ -1560,7 +1487,6 @@ setMethod("stddev", #' #' @rdname column_aggregate_functions #' @aliases stddev_pop stddev_pop,Column-method -#' @export #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1574,7 +1500,6 @@ setMethod("stddev_pop", #' #' @rdname column_aggregate_functions #' @aliases stddev_samp stddev_samp,Column-method -#' @export #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1588,7 +1513,6 @@ setMethod("stddev_samp", #' #' @rdname column_nonaggregate_functions #' @aliases struct struct,characterOrColumn-method -#' @export #' @examples #' #' \dontrun{ @@ -1614,7 +1538,6 @@ setMethod("struct", #' #' @rdname column_math_functions #' @aliases sqrt sqrt,Column-method -#' @export #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1628,7 +1551,6 @@ setMethod("sqrt", #' #' @rdname column_aggregate_functions #' @aliases sum sum,Column-method -#' @export #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1642,7 +1564,6 @@ setMethod("sum", #' #' @rdname column_aggregate_functions #' @aliases sumDistinct sumDistinct,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1663,7 +1584,6 @@ setMethod("sumDistinct", #' #' @rdname column_math_functions #' @aliases tan tan,Column-method -#' @export #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1678,7 +1598,6 @@ setMethod("tan", #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method -#' @export #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1693,7 +1612,6 @@ setMethod("tanh", #' #' @rdname column_math_functions #' @aliases toDegrees toDegrees,Column-method -#' @export #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1708,7 +1626,6 @@ setMethod("toDegrees", #' #' @rdname column_math_functions #' @aliases toRadians toRadians,Column-method -#' @export #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1728,7 +1645,6 @@ setMethod("toRadians", #' #' @rdname column_datetime_functions #' @aliases to_date to_date,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1749,7 +1665,6 @@ setMethod("to_date", #' @rdname column_datetime_functions #' @aliases to_date,Column,character-method -#' @export #' @note to_date(Column, character) since 2.2.0 setMethod("to_date", signature(x = "Column", format = "character"), @@ -1765,7 +1680,6 @@ setMethod("to_date", #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1803,7 +1717,6 @@ setMethod("to_json", signature(x = "Column"), #' #' @rdname column_datetime_functions #' @aliases to_timestamp to_timestamp,Column,missing-method -#' @export #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1814,7 +1727,6 @@ setMethod("to_timestamp", #' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method -#' @export #' @note to_timestamp(Column, character) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "character"), @@ -1829,7 +1741,6 @@ setMethod("to_timestamp", #' #' @rdname column_string_functions #' @aliases trim trim,Column,missing-method -#' @export #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column", trimString = "missing"), @@ -1840,7 +1751,6 @@ setMethod("trim", #' @rdname column_string_functions #' @aliases trim,Column,character-method -#' @export #' @note trim(Column, character) since 2.3.0 setMethod("trim", signature(x = "Column", trimString = "character"), @@ -1855,7 +1765,6 @@ setMethod("trim", #' #' @rdname column_string_functions #' @aliases unbase64 unbase64,Column-method -#' @export #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1870,7 +1779,6 @@ setMethod("unbase64", #' #' @rdname column_math_functions #' @aliases unhex unhex,Column-method -#' @export #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -1884,7 +1792,6 @@ setMethod("unhex", #' #' @rdname column_string_functions #' @aliases upper upper,Column-method -#' @export #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1898,7 +1805,6 @@ setMethod("upper", #' #' @rdname column_aggregate_functions #' @aliases var var,Column-method -#' @export #' @examples #' #'\dontrun{ @@ -1913,7 +1819,6 @@ setMethod("var", #' @rdname column_aggregate_functions #' @aliases variance variance,Column-method -#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1927,7 +1832,6 @@ setMethod("variance", #' #' @rdname column_aggregate_functions #' @aliases var_pop var_pop,Column-method -#' @export #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -1941,7 +1845,6 @@ setMethod("var_pop", #' #' @rdname column_aggregate_functions #' @aliases var_samp var_samp,Column-method -#' @export #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -1955,7 +1858,6 @@ setMethod("var_samp", #' #' @rdname column_datetime_functions #' @aliases weekofyear weekofyear,Column-method -#' @export #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -1969,7 +1871,6 @@ setMethod("weekofyear", #' #' @rdname column_datetime_functions #' @aliases year year,Column-method -#' @export #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -1985,7 +1886,6 @@ setMethod("year", #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method -#' @export #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2001,7 +1901,6 @@ setMethod("atan2", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2025,7 +1924,6 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases hypot hypot,Column-method -#' @export #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2041,7 +1939,6 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname column_string_functions #' @aliases levenshtein levenshtein,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2064,7 +1961,6 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method -#' @export #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2082,7 +1978,6 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method -#' @export #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2099,7 +1994,6 @@ setMethod("nanvl", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases pmod pmod,Column-method -#' @export #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2114,7 +2008,6 @@ setMethod("pmod", signature(y = "Column"), #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method -#' @export #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2128,7 +2021,6 @@ setMethod("approxCountDistinct", #' #' @rdname column_aggregate_functions #' @aliases countDistinct countDistinct,Column-method -#' @export #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2148,7 +2040,6 @@ setMethod("countDistinct", #' #' @rdname column_string_functions #' @aliases concat concat,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2177,7 +2068,6 @@ setMethod("concat", #' #' @rdname column_nonaggregate_functions #' @aliases greatest greatest,Column-method -#' @export #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2197,7 +2087,6 @@ setMethod("greatest", #' #' @rdname column_nonaggregate_functions #' @aliases least least,Column-method -#' @export #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2216,7 +2105,6 @@ setMethod("least", #' #' @rdname column_aggregate_functions #' @aliases n_distinct n_distinct,Column-method -#' @export #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -2226,7 +2114,6 @@ setMethod("n_distinct", signature(x = "Column"), #' @rdname count #' @name n #' @aliases n,Column-method -#' @export #' @examples \dontrun{n(df$c)} #' @note n since 1.4.0 setMethod("n", signature(x = "Column"), @@ -2245,7 +2132,6 @@ setMethod("n", signature(x = "Column"), #' @rdname column_datetime_diff_functions #' #' @aliases date_format date_format,Column,character-method -#' @export #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2263,7 +2149,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method -#' @export #' @examples #' #' \dontrun{ @@ -2306,7 +2191,6 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") #' @rdname column_datetime_diff_functions #' #' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2328,7 +2212,6 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_string_functions #' @aliases instr instr,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2351,7 +2234,6 @@ setMethod("instr", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases next_day next_day,Column,character-method -#' @export #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2366,7 +2248,6 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method -#' @export #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2379,7 +2260,6 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases add_months add_months,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2400,7 +2280,6 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' #' @rdname column_datetime_diff_functions #' @aliases date_add date_add,Column,numeric-method -#' @export #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2414,7 +2293,6 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname column_datetime_diff_functions #' #' @aliases date_sub date_sub,Column,numeric-method -#' @export #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2431,7 +2309,6 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' #' @rdname column_string_functions #' @aliases format_number format_number,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2454,7 +2331,6 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @rdname column_misc_functions #' @aliases sha2 sha2,Column,numeric-method -#' @export #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2468,7 +2344,6 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftLeft shiftLeft,Column,numeric-method -#' @export #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2484,7 +2359,6 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method -#' @export #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2500,7 +2374,6 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method -#' @export #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2517,7 +2390,6 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method -#' @export #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2533,7 +2405,6 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param toBase base to convert to. #' @rdname column_math_functions #' @aliases conv conv,Column,numeric,numeric-method -#' @export #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { @@ -2551,7 +2422,6 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' #' @rdname column_nonaggregate_functions #' @aliases expr expr,character-method -#' @export #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2566,7 +2436,6 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @rdname column_string_functions #' @aliases format_string format_string,character,Column-method -#' @export #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2587,7 +2456,6 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname column_datetime_functions #' #' @aliases from_unixtime from_unixtime,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2629,7 +2497,6 @@ setMethod("from_unixtime", signature(x = "Column"), #' \code{startTime} as \code{"15 minutes"}. #' @rdname column_datetime_functions #' @aliases window window,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2680,7 +2547,6 @@ setMethod("window", signature(x = "Column"), #' @param pos start position of search. #' @rdname column_string_functions #' @aliases locate locate,character,Column-method -#' @export #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2697,7 +2563,6 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param pad a character string to be padded with. #' @rdname column_string_functions #' @aliases lpad lpad,Column,numeric,character-method -#' @export #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2714,7 +2579,6 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. #' @aliases rand rand,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -2729,7 +2593,6 @@ setMethod("rand", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method -#' @export #' @note rand(numeric) since 1.5.0 setMethod("rand", signature(seed = "numeric"), function(seed) { @@ -2743,7 +2606,6 @@ setMethod("rand", signature(seed = "numeric"), #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method -#' @export #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2753,7 +2615,6 @@ setMethod("randn", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method -#' @export #' @note randn(numeric) since 1.5.0 setMethod("randn", signature(seed = "numeric"), function(seed) { @@ -2770,7 +2631,6 @@ setMethod("randn", signature(seed = "numeric"), #' @param idx a group index. #' @rdname column_string_functions #' @aliases regexp_extract regexp_extract,Column,character,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2799,7 +2659,6 @@ setMethod("regexp_extract", #' @param replacement a character string that a matched \code{pattern} is replaced with. #' @rdname column_string_functions #' @aliases regexp_replace regexp_replace,Column,character,character-method -#' @export #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2815,7 +2674,6 @@ setMethod("regexp_replace", #' #' @rdname column_string_functions #' @aliases rpad rpad,Column,numeric,character-method -#' @export #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2838,7 +2696,6 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' counting from the right. #' @rdname column_string_functions #' @aliases substring_index substring_index,Column,character,numeric-method -#' @export #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2861,7 +2718,6 @@ setMethod("substring_index", #' at the same location, if any. #' @rdname column_string_functions #' @aliases translate translate,Column,character,character-method -#' @export #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -2876,7 +2732,6 @@ setMethod("translate", #' #' @rdname column_datetime_functions #' @aliases unix_timestamp unix_timestamp,missing,missing-method -#' @export #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -2886,7 +2741,6 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method -#' @export #' @note unix_timestamp(Column) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { @@ -2896,7 +2750,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method -#' @export #' @note unix_timestamp(Column, character) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -2912,7 +2765,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. #' @aliases when when,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2941,7 +2793,6 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. #' @aliases ifelse ifelse,Column-method -#' @export #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -2967,7 +2818,6 @@ setMethod("ifelse", #' #' @rdname column_window_functions #' @aliases cume_dist cume_dist,missing-method -#' @export #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2988,7 +2838,6 @@ setMethod("cume_dist", #' #' @rdname column_window_functions #' @aliases dense_rank dense_rank,missing-method -#' @export #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -3005,7 +2854,6 @@ setMethod("dense_rank", #' #' @rdname column_window_functions #' @aliases lag lag,characterOrColumn-method -#' @export #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -3030,7 +2878,6 @@ setMethod("lag", #' #' @rdname column_window_functions #' @aliases lead lead,characterOrColumn,numeric-method -#' @export #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -3054,7 +2901,6 @@ setMethod("lead", #' #' @rdname column_window_functions #' @aliases ntile ntile,numeric-method -#' @export #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3072,7 +2918,6 @@ setMethod("ntile", #' #' @rdname column_window_functions #' @aliases percent_rank percent_rank,missing-method -#' @export #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3093,7 +2938,6 @@ setMethod("percent_rank", #' #' @rdname column_window_functions #' @aliases rank rank,missing-method -#' @export #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3104,7 +2948,6 @@ setMethod("rank", #' @rdname column_window_functions #' @aliases rank,ANY-method -#' @export setMethod("rank", signature(x = "ANY"), function(x, ...) { @@ -3118,7 +2961,6 @@ setMethod("rank", #' #' @rdname column_window_functions #' @aliases row_number row_number,missing-method -#' @export #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3136,7 +2978,6 @@ setMethod("row_number", #' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method -#' @export #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3150,7 +2991,6 @@ setMethod("array_contains", #' #' @rdname column_collection_functions #' @aliases map_keys map_keys,Column-method -#' @export #' @note map_keys since 2.3.0 setMethod("map_keys", signature(x = "Column"), @@ -3164,7 +3004,6 @@ setMethod("map_keys", #' #' @rdname column_collection_functions #' @aliases map_values map_values,Column-method -#' @export #' @note map_values since 2.3.0 setMethod("map_values", signature(x = "Column"), @@ -3178,7 +3017,6 @@ setMethod("map_values", #' #' @rdname column_collection_functions #' @aliases explode explode,Column-method -#' @export #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3192,7 +3030,6 @@ setMethod("explode", #' #' @rdname column_collection_functions #' @aliases size size,Column-method -#' @export #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3210,7 +3047,6 @@ setMethod("size", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method -#' @export #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3225,7 +3061,6 @@ setMethod("sort_array", #' #' @rdname column_collection_functions #' @aliases posexplode posexplode,Column-method -#' @export #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3240,7 +3075,6 @@ setMethod("posexplode", #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method -#' @export #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3261,7 +3095,6 @@ setMethod("create_array", #' #' @rdname column_nonaggregate_functions #' @aliases create_map create_map,Column-method -#' @export #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3279,7 +3112,6 @@ setMethod("create_map", #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3299,7 +3131,6 @@ setMethod("collect_list", #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method -#' @export #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3314,7 +3145,6 @@ setMethod("collect_set", #' #' @rdname column_string_functions #' @aliases split_string split_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3337,7 +3167,6 @@ setMethod("split_string", #' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3360,7 +3189,6 @@ setMethod("repeat_string", #' #' @rdname column_collection_functions #' @aliases explode_outer explode_outer,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3385,7 +3213,6 @@ setMethod("explode_outer", #' #' @rdname column_collection_functions #' @aliases posexplode_outer posexplode_outer,Column-method -#' @export #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), @@ -3406,7 +3233,6 @@ setMethod("posexplode_outer", #' @name not #' @aliases not,Column-method #' @family non-aggregate functions -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -3434,7 +3260,6 @@ setMethod("not", #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3467,7 +3292,6 @@ setMethod("grouping_bit", #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3502,7 +3326,6 @@ setMethod("grouping_id", #' #' @rdname column_nonaggregate_functions #' @aliases input_file_name input_file_name,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -3520,7 +3343,6 @@ setMethod("input_file_name", signature("missing"), #' #' @rdname column_datetime_functions #' @aliases trunc trunc,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3540,7 +3362,6 @@ setMethod("trunc", #' #' @rdname column_datetime_functions #' @aliases date_trunc date_trunc,character,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3559,7 +3380,6 @@ setMethod("date_trunc", #' #' @rdname column_datetime_functions #' @aliases current_date current_date,missing-method -#' @export #' @examples #' \dontrun{ #' head(select(df, current_date(), current_timestamp()))} @@ -3576,7 +3396,6 @@ setMethod("current_date", #' #' @rdname column_datetime_functions #' @aliases current_timestamp current_timestamp,missing-method -#' @export #' @note current_timestamp since 2.3.0 setMethod("current_timestamp", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e0dde3339fabc..6fba4b6c761dd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -19,7 +19,6 @@ # @rdname aggregateRDD # @seealso reduce -# @export setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) @@ -27,21 +26,17 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition -# @export setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") }) # @rdname checkpoint-methods -# @export setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) # @rdname collect-methods -# @export setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) # @rdname collect-methods -# @export setGeneric("collectPartition", function(x, partitionId) { standardGeneric("collectPartition") @@ -52,19 +47,15 @@ setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue -# @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @rdname crosstab -# @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @rdname freqItems -# @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) # @rdname approxQuantile -# @export setGeneric("approxQuantile", function(x, cols, probabilities, relativeError) { standardGeneric("approxQuantile") @@ -73,18 +64,15 @@ setGeneric("approxQuantile", setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD -# @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") }) # @rdname flatMap -# @export setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @rdname fold # @seealso reduce -# @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) @@ -95,17 +83,14 @@ setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachParti setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) # @rdname glom -# @export setGeneric("glom", function(x) { standardGeneric("glom") }) # @rdname histogram -# @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) # @rdname keyBy -# @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) @@ -123,47 +108,37 @@ setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) # @rdname maximum -# @export setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @rdname minimum -# @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) # @rdname sumRDD -# @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @rdname name -# @export setGeneric("name", function(x) { standardGeneric("name") }) # @rdname getNumPartitionsRDD -# @export setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") }) # @rdname getNumPartitions -# @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD -# @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) # @rdname pivot -# @export setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) # @rdname reduce -# @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") }) # @rdname sampleRDD -# @export setGeneric("sampleRDD", function(x, withReplacement, fraction, seed) { standardGeneric("sampleRDD") @@ -171,21 +146,17 @@ setGeneric("sampleRDD", # @rdname saveAsObjectFile # @seealso objectFile -# @export setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) # @rdname saveAsTextFile -# @export setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) # @rdname setName -# @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) # @rdname sortBy -# @export setGeneric("sortBy", function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") @@ -194,88 +165,71 @@ setGeneric("sortBy", setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered -# @export setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) # @rdname takeSample -# @export setGeneric("takeSample", function(x, withReplacement, num, seed) { standardGeneric("takeSample") }) # @rdname top -# @export setGeneric("top", function(x, num) { standardGeneric("top") }) # @rdname unionRDD -# @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD -# @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD -# @export setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex # @seealso zipWithUniqueId -# @export setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) # @rdname zipWithUniqueId # @seealso zipWithIndex -# @export setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) ############ Binary Functions ############# # @rdname cartesian -# @export setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) # @rdname countByKey -# @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) # @rdname flatMapValues -# @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) # @rdname intersection -# @export setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) # @rdname keys -# @export setGeneric("keys", function(x) { standardGeneric("keys") }) # @rdname lookup -# @export setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) # @rdname mapValues -# @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) # @rdname sampleByKey -# @export setGeneric("sampleByKey", function(x, withReplacement, fractions, seed) { standardGeneric("sampleByKey") }) # @rdname values -# @export setGeneric("values", function(x) { standardGeneric("values") }) @@ -283,14 +237,12 @@ setGeneric("values", function(x) { standardGeneric("values") }) # @rdname aggregateByKey # @seealso foldByKey, combineByKey -# @export setGeneric("aggregateByKey", function(x, zeroValue, seqOp, combOp, numPartitions) { standardGeneric("aggregateByKey") }) # @rdname cogroup -# @export setGeneric("cogroup", function(..., numPartitions) { standardGeneric("cogroup") @@ -299,7 +251,6 @@ setGeneric("cogroup", # @rdname combineByKey # @seealso groupByKey, reduceByKey -# @export setGeneric("combineByKey", function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { standardGeneric("combineByKey") @@ -307,64 +258,53 @@ setGeneric("combineByKey", # @rdname foldByKey # @seealso aggregateByKey, combineByKey -# @export setGeneric("foldByKey", function(x, zeroValue, func, numPartitions) { standardGeneric("foldByKey") }) # @rdname join-methods -# @export setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) # @rdname groupByKey # @seealso reduceByKey -# @export setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) # @rdname join-methods -# @export setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @rdname join-methods -# @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey -# @export setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) # @rdname reduceByKeyLocally # @seealso reduceByKey -# @export setGeneric("reduceByKeyLocally", function(x, combineFunc) { standardGeneric("reduceByKeyLocally") }) # @rdname join-methods -# @export setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) # @rdname sortByKey -# @export setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) # @rdname subtract -# @export setGeneric("subtract", function(x, other, numPartitions = 1) { standardGeneric("subtract") }) # @rdname subtractByKey -# @export setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") @@ -374,7 +314,6 @@ setGeneric("subtractByKey", ################### Broadcast Variable Methods ################# # @rdname broadcast -# @export setGeneric("value", function(bcast) { standardGeneric("value") }) @@ -384,7 +323,6 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @param ... further arguments to be passed to or from other methods. #' @return A SparkDataFrame. #' @rdname summarize -#' @export setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias @@ -399,11 +337,9 @@ setGeneric("agg", function(x, ...) { standardGeneric("agg") }) NULL #' @rdname arrange -#' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname as.data.frame -#' @export setGeneric("as.data.frame", function(x, row.names = NULL, optional = FALSE, ...) { standardGeneric("as.data.frame") @@ -411,52 +347,41 @@ setGeneric("as.data.frame", # Do not document the generic because of signature changes across R versions #' @noRd -#' @export setGeneric("attach") #' @rdname cache -#' @export setGeneric("cache", function(x) { standardGeneric("cache") }) #' @rdname checkpoint -#' @export setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce #' @param x a SparkDataFrame. #' @param ... additional argument(s). -#' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) #' @rdname collect -#' @export setGeneric("collect", function(x, ...) { standardGeneric("collect") }) #' @param do.NULL currently not used. #' @param prefix currently not used. #' @rdname columns -#' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) #' @rdname columns -#' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) #' @rdname coltypes -#' @export setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) #' @rdname coltypes -#' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) #' @rdname columns -#' @export setGeneric("columns", function(x) {standardGeneric("columns") }) #' @param x a GroupedData or Column. #' @rdname count -#' @export setGeneric("count", function(x) { standardGeneric("count") }) #' @rdname cov @@ -464,7 +389,6 @@ setGeneric("count", function(x) { standardGeneric("count") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname corr @@ -472,294 +396,229 @@ setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname cov -#' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) #' @rdname cov -#' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) #' @rdname createOrReplaceTempView -#' @export setGeneric("createOrReplaceTempView", function(x, viewName) { standardGeneric("createOrReplaceTempView") }) # @rdname crossJoin -# @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) #' @rdname cube -#' @export setGeneric("cube", function(x, ...) { standardGeneric("cube") }) #' @rdname dapply -#' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @rdname dapplyCollect -#' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapply -#' @export setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapplyCollect -#' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) # @rdname getNumPartitions -# @export setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) #' @rdname describe -#' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname distinct -#' @export setGeneric("distinct", function(x) { standardGeneric("distinct") }) #' @rdname drop -#' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) #' @rdname dropDuplicates -#' @export setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") }) #' @rdname nafunctions -#' @export setGeneric("dropna", function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { standardGeneric("dropna") }) #' @rdname nafunctions -#' @export setGeneric("na.omit", function(object, ...) { standardGeneric("na.omit") }) #' @rdname dtypes -#' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain -#' @export #' @param x a SparkDataFrame or a StreamingQuery. #' @param extended Logical. If extended is FALSE, prints only the physical plan. #' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except -#' @export setGeneric("except", function(x, y) { standardGeneric("except") }) #' @rdname nafunctions -#' @export setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) #' @rdname filter -#' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @rdname first -#' @export setGeneric("first", function(x, ...) { standardGeneric("first") }) #' @rdname groupBy -#' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @rdname groupBy -#' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) #' @rdname hint -#' @export setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) #' @rdname insertInto -#' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) #' @rdname intersect -#' @export setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @rdname isLocal -#' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @rdname isStreaming -#' @export setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @rdname limit -#' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) #' @rdname localCheckpoint -#' @export setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) #' @rdname merge -#' @export setGeneric("merge") #' @rdname mutate -#' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname orderBy -#' @export setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) #' @rdname persist -#' @export setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) #' @rdname printSchema -#' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) #' @rdname registerTempTable-deprecated -#' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) #' @rdname rename -#' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition -#' @export setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample -#' @export setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) #' @rdname rollup -#' @export setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample -#' @export setGeneric("sample_frac", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy -#' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) #' @rdname saveAsTable -#' @export setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) { standardGeneric("saveAsTable") }) -#' @export setGeneric("str") #' @rdname take -#' @export setGeneric("take", function(x, num) { standardGeneric("take") }) #' @rdname mutate -#' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) { standardGeneric("write.df") }) #' @rdname write.df -#' @export setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { standardGeneric("saveDF") }) #' @rdname write.jdbc -#' @export setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { standardGeneric("write.jdbc") }) #' @rdname write.json -#' @export setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc -#' @export setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet -#' @export setGeneric("write.parquet", function(x, path, ...) { standardGeneric("write.parquet") }) #' @rdname write.parquet -#' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) #' @rdname write.stream -#' @export setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { standardGeneric("write.stream") }) #' @rdname write.text -#' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema -#' @export setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select -#' @export setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr -#' @export setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) #' @rdname showDF -#' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname storageLevel -# @export setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) #' @rdname subset -#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname summarize -#' @export setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary -#' @export setGeneric("summary", function(object, ...) { standardGeneric("summary") }) setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -767,830 +626,660 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union -#' @export setGeneric("union", function(x, y) { standardGeneric("union") }) #' @rdname union -#' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @rdname unionByName -#' @export setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) #' @rdname unpersist -#' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) #' @rdname filter -#' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @rdname with -#' @export setGeneric("with") #' @rdname withColumn -#' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) #' @rdname rename -#' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) #' @rdname withWatermark -#' @export setGeneric("withWatermark", function(x, eventTime, delayThreshold) { standardGeneric("withWatermark") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) #' @rdname randomSplit -#' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) #' @rdname broadcast -#' @export setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) ###################### Column Methods ########################## #' @rdname columnfunctions -#' @export setGeneric("asc", function(x) { standardGeneric("asc") }) #' @rdname between -#' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @rdname cast -#' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @rdname columnfunctions #' @param x a Column object. #' @param ... additional argument(s). -#' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) #' @rdname columnfunctions -#' @export setGeneric("desc", function(x) { standardGeneric("desc") }) #' @rdname endsWith -#' @export setGeneric("endsWith", function(x, suffix) { standardGeneric("endsWith") }) #' @rdname columnfunctions -#' @export setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @rdname columnfunctions -#' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname columnfunctions -#' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname columnfunctions -#' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @rdname columnfunctions -#' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname columnfunctions -#' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname columnfunctions -#' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname startsWith -#' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise -#' @export setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @rdname over -#' @export setGeneric("over", function(x, window) { standardGeneric("over") }) #' @rdname eq_null_safe -#' @export setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) ###################### WindowSpec Methods ########################## #' @rdname partitionBy -#' @export setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) #' @rdname rowsBetween -#' @export setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) #' @rdname rangeBetween -#' @export setGeneric("rangeBetween", function(x, start, end) { standardGeneric("rangeBetween") }) #' @rdname windowPartitionBy -#' @export setGeneric("windowPartitionBy", function(col, ...) { standardGeneric("windowPartitionBy") }) #' @rdname windowOrderBy -#' @export setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy") }) ###################### Expression Function Methods ########################## #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. #' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg -#' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column -#' @export setGeneric("column", function(x) { standardGeneric("column") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_date", function(x = "missing") { standardGeneric("current_date") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_timestamp", function(x = "missing") { standardGeneric("current_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last -#' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count -#' @export setGeneric("n", function(x) { standardGeneric("n") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not -#' @export setGeneric("not", function(x) { standardGeneric("not") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("trim", function(x, trimString) { standardGeneric("trim") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) @@ -1598,142 +1287,110 @@ setGeneric("year", function(x) { standardGeneric("year") }) ###################### Spark.ML Methods ########################## #' @rdname fitted -#' @export setGeneric("fitted") # Do not carry stats::glm usage and param here, and do not document the generic -#' @export #' @noRd setGeneric("glm") #' @param object a fitted ML model object. #' @param ... additional argument(s) passed to the method. #' @rdname predict -#' @export setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind -#' @export setGeneric("rbind", signature = "...") #' @rdname spark.als -#' @export setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) #' @rdname spark.bisectingKmeans -#' @export setGeneric("spark.bisectingKmeans", function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") }) #' @rdname spark.gaussianMixture -#' @export setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) #' @rdname spark.gbt -#' @export setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) #' @rdname spark.glm -#' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) #' @rdname spark.isoreg -#' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) #' @rdname spark.kmeans -#' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) #' @rdname spark.kstest -#' @export setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) #' @rdname spark.lda -#' @export setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) #' @rdname spark.logit -#' @export setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp -#' @export setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes -#' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) #' @rdname spark.decisionTree -#' @export setGeneric("spark.decisionTree", function(data, formula, ...) { standardGeneric("spark.decisionTree") }) #' @rdname spark.randomForest -#' @export setGeneric("spark.randomForest", function(data, formula, ...) { standardGeneric("spark.randomForest") }) #' @rdname spark.survreg -#' @export setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear -#' @export setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") }) #' @rdname spark.lda -#' @export setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") }) #' @rdname spark.lda -#' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml -#' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) ###################### Streaming Methods ########################## #' @rdname awaitTermination -#' @export setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive -#' @export setGeneric("isActive", function(x) { standardGeneric("isActive") }) #' @rdname lastProgress -#' @export setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) #' @rdname queryName -#' @export setGeneric("queryName", function(x) { standardGeneric("queryName") }) #' @rdname status -#' @export setGeneric("status", function(x) { standardGeneric("status") }) #' @rdname stopQuery -#' @export setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 54ef9f07d6fae..f751b952f3915 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -30,7 +30,6 @@ setOldClass("jobj") #' @seealso groupBy #' #' @param sgd A Java object reference to the backing Scala GroupedData -#' @export #' @note GroupedData since 1.4.0 setClass("GroupedData", slots = list(sgd = "jobj")) @@ -48,7 +47,6 @@ groupedData <- function(sgd) { #' @rdname show #' @aliases show,GroupedData-method -#' @export #' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "GroupedData", #' @return A SparkDataFrame. #' @rdname count #' @aliases count,GroupedData-method -#' @export #' @examples #' \dontrun{ #' count(groupBy(df, "name")) @@ -87,7 +84,6 @@ setMethod("count", #' @aliases agg,GroupedData-method #' @name agg #' @family agg_funcs -#' @export #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -150,7 +146,6 @@ methods <- c("avg", "max", "mean", "min", "sum") #' @rdname pivot #' @aliases pivot,GroupedData,character-method #' @name pivot -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -202,7 +197,6 @@ createMethods() #' @rdname gapply #' @aliases gapply,GroupedData-method #' @name gapply -#' @export #' @note gapply(GroupedData) since 2.0.0 setMethod("gapply", signature(x = "GroupedData"), @@ -216,7 +210,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @aliases gapplyCollect,GroupedData-method #' @name gapplyCollect -#' @export #' @note gapplyCollect(GroupedData) since 2.0.0 setMethod("gapplyCollect", signature(x = "GroupedData"), diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 04dc7562e5346..6d1edf6b6f3cf 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -58,7 +58,6 @@ #' @rdname install.spark #' @name install.spark #' @aliases install.spark -#' @export #' @examples #'\dontrun{ #' install.spark() diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R index bb5c77544a3da..9a1b26b0fa3c5 100644 --- a/R/pkg/R/jvm.R +++ b/R/pkg/R/jvm.R @@ -35,7 +35,6 @@ #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} #' @rdname sparkR.callJMethod #' @examples @@ -69,7 +68,6 @@ sparkR.callJMethod <- function(x, methodName, ...) { #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} #' @rdname sparkR.callJStatic #' @examples @@ -100,7 +98,6 @@ sparkR.callJStatic <- function(x, methodName, ...) { #' @param ... arguments to be passed to the constructor. #' @return the object created. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} #' @rdname sparkR.newJObject #' @examples diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index f6e9b1357561b..2964fdeff0957 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -21,28 +21,24 @@ #' S4 class that represents an LinearSVCModel #' #' @param jobj a Java object reference to the backing Scala LinearSVCModel -#' @export #' @note LinearSVCModel since 2.2.0 setClass("LinearSVCModel", representation(jobj = "jobj")) #' S4 class that represents an LogisticRegressionModel #' #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a MultilayerPerceptronClassificationModel #' #' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export #' @note MultilayerPerceptronClassificationModel since 2.1.0 setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a NaiveBayesModel #' #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) @@ -82,7 +78,6 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @rdname spark.svmLinear #' @aliases spark.svmLinear,SparkDataFrame,formula-method #' @name spark.svmLinear -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -131,7 +126,6 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu #' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method -#' @export #' @note predict(LinearSVCModel) since 2.2.0 setMethod("predict", signature(object = "LinearSVCModel"), function(object, newData) { @@ -146,7 +140,6 @@ setMethod("predict", signature(object = "LinearSVCModel"), #' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method -#' @export #' @note summary(LinearSVCModel) since 2.2.0 setMethod("summary", signature(object = "LinearSVCModel"), function(object) { @@ -169,7 +162,6 @@ setMethod("summary", signature(object = "LinearSVCModel"), #' #' @rdname spark.svmLinear #' @aliases write.ml,LinearSVCModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.2.0 setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -257,7 +249,6 @@ function(object, path, overwrite = FALSE) { #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -374,7 +365,6 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") #' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method -#' @export #' @note summary(LogisticRegressionModel) since 2.1.0 setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { @@ -402,7 +392,6 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. #' @rdname spark.logit #' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export #' @note predict(LogisticRegressionModel) since 2.1.0 setMethod("predict", signature(object = "LogisticRegressionModel"), function(object, newData) { @@ -417,7 +406,6 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), #' #' @rdname spark.logit #' @aliases write.ml,LogisticRegressionModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -458,7 +446,6 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} -#' @export #' @examples #' \dontrun{ #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") @@ -517,7 +504,6 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), #' For \code{weights}, it is a numeric vector with length equal to the expected #' given the architecture (i.e., for 8-10-2 network, 112 connection weights). #' @rdname spark.mlp -#' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method #' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), @@ -538,7 +524,6 @@ setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel #' "prediction". #' @rdname spark.mlp #' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export #' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), function(object, newData) { @@ -553,7 +538,6 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel #' #' @rdname spark.mlp #' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export #' @seealso \link{write.ml} #' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", @@ -585,7 +569,6 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @aliases spark.naiveBayes,SparkDataFrame,formula-method #' @name spark.naiveBayes #' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export #' @examples #' \dontrun{ #' data <- as.data.frame(UCBAdmissions) @@ -624,7 +607,6 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form #' The list includes \code{apriori} (the label distribution) and #' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes -#' @export #' @note summary(NaiveBayesModel) since 2.0.0 setMethod("summary", signature(object = "NaiveBayesModel"), function(object) { @@ -648,7 +630,6 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named #' "prediction". #' @rdname spark.naiveBayes -#' @export #' @note predict(NaiveBayesModel) since 2.0.0 setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { @@ -662,7 +643,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.naiveBayes -#' @export #' @seealso \link{write.ml} #' @note write.ml(NaiveBayesModel, character) since 2.0.0 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index a25bf81c6d977..900be685824da 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -20,28 +20,24 @@ #' S4 class that represents a BisectingKMeansModel #' #' @param jobj a Java object reference to the backing Scala BisectingKMeansModel -#' @export #' @note BisectingKMeansModel since 2.2.0 setClass("BisectingKMeansModel", representation(jobj = "jobj")) #' S4 class that represents a GaussianMixtureModel #' #' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export #' @note GaussianMixtureModel since 2.1.0 setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' S4 class that represents a KMeansModel #' #' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export #' @note KMeansModel since 2.0.0 setClass("KMeansModel", representation(jobj = "jobj")) #' S4 class that represents an LDAModel #' #' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export #' @note LDAModel since 2.1.0 setClass("LDAModel", representation(jobj = "jobj")) @@ -68,7 +64,6 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @rdname spark.bisectingKmeans #' @aliases spark.bisectingKmeans,SparkDataFrame,formula-method #' @name spark.bisectingKmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -117,7 +112,6 @@ setMethod("spark.bisectingKmeans", signature(data = "SparkDataFrame", formula = #' (cluster centers of the transformed data; cluster is NULL if is.loaded is TRUE), #' and \code{is.loaded} (whether the model is loaded from a saved file). #' @rdname spark.bisectingKmeans -#' @export #' @note summary(BisectingKMeansModel) since 2.2.0 setMethod("summary", signature(object = "BisectingKMeansModel"), function(object) { @@ -144,7 +138,6 @@ setMethod("summary", signature(object = "BisectingKMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a bisecting k-means model. #' @rdname spark.bisectingKmeans -#' @export #' @note predict(BisectingKMeansModel) since 2.2.0 setMethod("predict", signature(object = "BisectingKMeansModel"), function(object, newData) { @@ -160,7 +153,6 @@ setMethod("predict", signature(object = "BisectingKMeansModel"), #' or \code{"classes"} for assigned classes. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname spark.bisectingKmeans -#' @export #' @note fitted since 2.2.0 setMethod("fitted", signature(object = "BisectingKMeansModel"), function(object, method = c("centers", "classes")) { @@ -181,7 +173,6 @@ setMethod("fitted", signature(object = "BisectingKMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.bisectingKmeans -#' @export #' @note write.ml(BisectingKMeansModel, character) since 2.2.0 setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -208,7 +199,6 @@ setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "charact #' @rdname spark.gaussianMixture #' @name spark.gaussianMixture #' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +241,6 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = #' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture -#' @export #' @note summary(GaussianMixtureModel) since 2.1.0 setMethod("summary", signature(object = "GaussianMixtureModel"), function(object) { @@ -291,7 +280,6 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), #' "prediction". #' @aliases predict,GaussianMixtureModel,SparkDataFrame-method #' @rdname spark.gaussianMixture -#' @export #' @note predict(GaussianMixtureModel) since 2.1.0 setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { @@ -306,7 +294,6 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' @aliases write.ml,GaussianMixtureModel,character-method #' @rdname spark.gaussianMixture -#' @export #' @note write.ml(GaussianMixtureModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -336,7 +323,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @rdname spark.kmeans #' @aliases spark.kmeans,SparkDataFrame,formula-method #' @name spark.kmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -385,7 +371,6 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' (the actual number of cluster centers. When using initMode = "random", #' \code{clusterSize} may not equal to \code{k}). #' @rdname spark.kmeans -#' @export #' @note summary(KMeansModel) since 2.0.0 setMethod("summary", signature(object = "KMeansModel"), function(object) { @@ -413,7 +398,6 @@ setMethod("summary", signature(object = "KMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a k-means model. #' @rdname spark.kmeans -#' @export #' @note predict(KMeansModel) since 2.0.0 setMethod("predict", signature(object = "KMeansModel"), function(object, newData) { @@ -431,7 +415,6 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param ... additional argument(s) passed to the method. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted -#' @export #' @examples #' \dontrun{ #' model <- spark.kmeans(trainingData, ~ ., 2) @@ -458,7 +441,6 @@ setMethod("fitted", signature(object = "KMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.kmeans -#' @export #' @note write.ml(KMeansModel, character) since 2.0.0 setMethod("write.ml", signature(object = "KMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -496,7 +478,6 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export #' @examples #' \dontrun{ #' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") @@ -558,7 +539,6 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' It is only for distributed LDA model (i.e., optimizer = "em")} #' @rdname spark.lda #' @aliases summary,LDAModel-method -#' @export #' @note summary(LDAModel) since 2.1.0 setMethod("summary", signature(object = "LDAModel"), function(object, maxTermsPerTopic) { @@ -596,7 +576,6 @@ setMethod("summary", signature(object = "LDAModel"), #' perplexity of the training data if missing argument "data". #' @rdname spark.lda #' @aliases spark.perplexity,LDAModel-method -#' @export #' @note spark.perplexity(LDAModel) since 2.1.0 setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), function(object, data) { @@ -611,7 +590,6 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr #' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export #' @note spark.posterior(LDAModel) since 2.1.0 setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), function(object, newData) { @@ -626,7 +604,6 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' #' @rdname spark.lda #' @aliases write.ml,LDAModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(LDAModel, character) since 2.1.0 setMethod("write.ml", signature(object = "LDAModel", path = "character"), diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index dfcb45a1b66c9..e2394906d8012 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -20,7 +20,6 @@ #' S4 class that represents a FPGrowthModel #' #' @param jobj a Java object reference to the backing Scala FPGrowthModel -#' @export #' @note FPGrowthModel since 2.2.0 setClass("FPGrowthModel", slots = list(jobj = "jobj")) @@ -45,7 +44,6 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' @rdname spark.fpGrowth #' @name spark.fpGrowth #' @aliases spark.fpGrowth,SparkDataFrame-method -#' @export #' @examples #' \dontrun{ #' raw_data <- read.df( @@ -109,7 +107,6 @@ setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), #' and \code{freq} (frequency of the itemset). #' @rdname spark.fpGrowth #' @aliases freqItemsets,FPGrowthModel-method -#' @export #' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), function(object) { @@ -125,7 +122,6 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), #' and \code{condfidence} (confidence). #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method -#' @export #' @note spark.associationRules(FPGrowthModel) since 2.2.0 setMethod("spark.associationRules", signature(object = "FPGrowthModel"), function(object) { @@ -138,7 +134,6 @@ setMethod("spark.associationRules", signature(object = "FPGrowthModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.fpGrowth #' @aliases predict,FPGrowthModel-method -#' @export #' @note predict(FPGrowthModel) since 2.2.0 setMethod("predict", signature(object = "FPGrowthModel"), function(object, newData) { @@ -153,7 +148,6 @@ setMethod("predict", signature(object = "FPGrowthModel"), #' if the output path exists. #' @rdname spark.fpGrowth #' @aliases write.ml,FPGrowthModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(FPGrowthModel, character) since 2.2.0 setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index 5441c4a4022a9..9a77b07462585 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -20,7 +20,6 @@ #' S4 class that represents an ALSModel #' #' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export #' @note ALSModel since 2.1.0 setClass("ALSModel", representation(jobj = "jobj")) @@ -55,7 +54,6 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als -#' @export #' @examples #' \dontrun{ #' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -118,7 +116,6 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), #' and \code{rank} (rank of the matrix factorization model). #' @rdname spark.als #' @aliases summary,ALSModel-method -#' @export #' @note summary(ALSModel) since 2.1.0 setMethod("summary", signature(object = "ALSModel"), function(object) { @@ -139,7 +136,6 @@ setMethod("summary", signature(object = "ALSModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.als #' @aliases predict,ALSModel-method -#' @export #' @note predict(ALSModel) since 2.1.0 setMethod("predict", signature(object = "ALSModel"), function(object, newData) { @@ -155,7 +151,6 @@ setMethod("predict", signature(object = "ALSModel"), #' #' @rdname spark.als #' @aliases write.ml,ALSModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(ALSModel, character) since 2.1.0 setMethod("write.ml", signature(object = "ALSModel", path = "character"), diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 545be5e1d89f0..95c1a29905197 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -21,21 +21,18 @@ #' S4 class that represents a AFTSurvivalRegressionModel #' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export #' @note AFTSurvivalRegressionModel since 2.0.0 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a generalized linear model #' #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export #' @note GeneralizedLinearRegressionModel since 2.0.0 setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' S4 class that represents an IsotonicRegressionModel #' #' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export #' @note IsotonicRegressionModel since 2.1.0 setClass("IsotonicRegressionModel", representation(jobj = "jobj")) @@ -85,7 +82,6 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -211,7 +207,6 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @aliases glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -244,7 +239,6 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in #' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm -#' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), function(object) { @@ -290,7 +284,6 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), #' @rdname spark.glm #' @param x summary object of fitted generalized linear model returned by \code{summary} function. -#' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { if (x$is.loaded) { @@ -324,7 +317,6 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named #' "prediction". #' @rdname spark.glm -#' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { @@ -338,7 +330,6 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.glm -#' @export #' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -363,7 +354,6 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -412,7 +402,6 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" #' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method -#' @export #' @note summary(IsotonicRegressionModel) since 2.1.0 setMethod("summary", signature(object = "IsotonicRegressionModel"), function(object) { @@ -429,7 +418,6 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export #' @note predict(IsotonicRegressionModel) since 2.1.0 setMethod("predict", signature(object = "IsotonicRegressionModel"), function(object, newData) { @@ -444,7 +432,6 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), #' #' @rdname spark.isoreg #' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export #' @note write.ml(IsotonicRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -477,7 +464,6 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(ovarian) @@ -517,7 +503,6 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' The list includes the model's \code{coefficients} (features, coefficients, #' intercept and log(scale)). #' @rdname spark.survreg -#' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), function(object) { @@ -537,7 +522,6 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values #' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg -#' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { @@ -550,7 +534,6 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @rdname spark.survreg -#' @export #' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 #' @seealso \link{write.ml} setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R index 3e013f1d45e38..f8c3329359961 100644 --- a/R/pkg/R/mllib_stat.R +++ b/R/pkg/R/mllib_stat.R @@ -20,7 +20,6 @@ #' S4 class that represents an KSTest #' #' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) @@ -52,7 +51,6 @@ setClass("KSTest", representation(jobj = "jobj")) #' @name spark.kstest #' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ #' MLlib: Hypothesis Testing} -#' @export #' @examples #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) @@ -94,7 +92,6 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), #' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method -#' @export #' @note summary(KSTest) since 2.1.0 setMethod("summary", signature(object = "KSTest"), function(object) { @@ -117,7 +114,6 @@ setMethod("summary", signature(object = "KSTest"), #' @rdname spark.kstest #' @param x summary object of KSTest returned by \code{summary}. -#' @export #' @note print.summary.KSTest since 2.1.0 print.summary.KSTest <- function(x, ...) { jobj <- x$jobj diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 4e5ddf22ee16d..6769be038efa9 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -20,42 +20,36 @@ #' S4 class that represents a GBTRegressionModel #' #' @param jobj a Java object reference to the backing Scala GBTRegressionModel -#' @export #' @note GBTRegressionModel since 2.1.0 setClass("GBTRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a GBTClassificationModel #' #' @param jobj a Java object reference to the backing Scala GBTClassificationModel -#' @export #' @note GBTClassificationModel since 2.1.0 setClass("GBTClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestRegressionModel #' #' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export #' @note RandomForestRegressionModel since 2.1.0 setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestClassificationModel #' #' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeRegressionModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel -#' @export #' @note DecisionTreeRegressionModel since 2.3.0 setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeClassificationModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel -#' @export #' @note DecisionTreeClassificationModel since 2.3.0 setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) @@ -179,7 +173,6 @@ print.summary.decisionTree <- function(x) { #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. #' @rdname spark.gbt #' @name spark.gbt -#' @export #' @examples #' \dontrun{ #' # fit a Gradient Boosted Tree Regression Model @@ -261,7 +254,6 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method -#' @export #' @note summary(GBTRegressionModel) since 2.1.0 setMethod("summary", signature(object = "GBTRegressionModel"), function(object) { @@ -275,7 +267,6 @@ setMethod("summary", signature(object = "GBTRegressionModel"), #' @param x summary object of Gradient Boosted Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.gbt -#' @export #' @note print.summary.GBTRegressionModel since 2.1.0 print.summary.GBTRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -285,7 +276,6 @@ print.summary.GBTRegressionModel <- function(x, ...) { #' @rdname spark.gbt #' @aliases summary,GBTClassificationModel-method -#' @export #' @note summary(GBTClassificationModel) since 2.1.0 setMethod("summary", signature(object = "GBTClassificationModel"), function(object) { @@ -297,7 +287,6 @@ setMethod("summary", signature(object = "GBTClassificationModel"), # Prints the summary of Gradient Boosted Tree Classification Model #' @rdname spark.gbt -#' @export #' @note print.summary.GBTClassificationModel since 2.1.0 print.summary.GBTClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -310,7 +299,6 @@ print.summary.GBTClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.gbt #' @aliases predict,GBTRegressionModel-method -#' @export #' @note predict(GBTRegressionModel) since 2.1.0 setMethod("predict", signature(object = "GBTRegressionModel"), function(object, newData) { @@ -319,7 +307,6 @@ setMethod("predict", signature(object = "GBTRegressionModel"), #' @rdname spark.gbt #' @aliases predict,GBTClassificationModel-method -#' @export #' @note predict(GBTClassificationModel) since 2.1.0 setMethod("predict", signature(object = "GBTClassificationModel"), function(object, newData) { @@ -334,7 +321,6 @@ setMethod("predict", signature(object = "GBTClassificationModel"), #' which means throw exception if the output path exists. #' @aliases write.ml,GBTRegressionModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -343,7 +329,6 @@ setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character #' @aliases write.ml,GBTClassificationModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -402,7 +387,6 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @return \code{spark.randomForest} returns a fitted Random Forest model. #' @rdname spark.randomForest #' @name spark.randomForest -#' @export #' @examples #' \dontrun{ #' # fit a Random Forest Regression Model @@ -480,7 +464,6 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method -#' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { @@ -494,7 +477,6 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -504,7 +486,6 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method -#' @export #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { @@ -516,7 +497,6 @@ setMethod("summary", signature(object = "RandomForestClassificationModel"), # Prints the summary of Random Forest Classification Model #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -529,7 +509,6 @@ print.summary.RandomForestClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method -#' @export #' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { @@ -538,7 +517,6 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method -#' @export #' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { @@ -554,7 +532,6 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), #' #' @aliases write.ml,RandomForestRegressionModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -563,7 +540,6 @@ setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = " #' @aliases write.ml,RandomForestClassificationModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -617,7 +593,6 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. #' @rdname spark.decisionTree #' @name spark.decisionTree -#' @export #' @examples #' \dontrun{ #' # fit a Decision Tree Regression Model @@ -690,7 +665,6 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method -#' @export #' @note summary(DecisionTreeRegressionModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeRegressionModel"), function(object) { @@ -704,7 +678,6 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"), #' @param x summary object of Decision Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeRegressionModel since 2.3.0 print.summary.DecisionTreeRegressionModel <- function(x, ...) { print.summary.decisionTree(x) @@ -714,7 +687,6 @@ print.summary.DecisionTreeRegressionModel <- function(x, ...) { #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeClassificationModel-method -#' @export #' @note summary(DecisionTreeClassificationModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeClassificationModel"), function(object) { @@ -726,7 +698,6 @@ setMethod("summary", signature(object = "DecisionTreeClassificationModel"), # Prints the summary of Decision Tree Classification Model #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeClassificationModel since 2.3.0 print.summary.DecisionTreeClassificationModel <- function(x, ...) { print.summary.decisionTree(x) @@ -739,7 +710,6 @@ print.summary.DecisionTreeClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeRegressionModel-method -#' @export #' @note predict(DecisionTreeRegressionModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeRegressionModel"), function(object, newData) { @@ -748,7 +718,6 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"), #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeClassificationModel-method -#' @export #' @note predict(DecisionTreeClassificationModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeClassificationModel"), function(object, newData) { @@ -764,7 +733,6 @@ setMethod("predict", signature(object = "DecisionTreeClassificationModel"), #' #' @aliases write.ml,DecisionTreeRegressionModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -773,7 +741,6 @@ setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = " #' @aliases write.ml,DecisionTreeClassificationModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..7d04bffcba3a4 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -31,7 +31,6 @@ #' MLlib model below. #' @rdname write.ml #' @name write.ml -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -48,7 +47,6 @@ NULL #' MLlib model below. #' @rdname predict #' @name predict -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -75,7 +73,6 @@ predict_internal <- function(object, newData) { #' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml -#' @export #' @seealso \link{write.ml} #' @examples #' \dontrun{ diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 65f418740c643..9831fc3cc6d01 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -29,7 +29,6 @@ #' @param ... additional structField objects #' @return a structType object #' @rdname structType -#' @export #' @examples #'\dontrun{ #' schema <- structType(structField("a", "integer"), structField("c", "string"), @@ -49,7 +48,6 @@ structType <- function(x, ...) { #' @rdname structType #' @method structType jobj -#' @export structType.jobj <- function(x, ...) { obj <- structure(list(), class = "structType") obj$jobj <- x @@ -59,7 +57,6 @@ structType.jobj <- function(x, ...) { #' @rdname structType #' @method structType structField -#' @export structType.structField <- function(x, ...) { fields <- list(x, ...) if (!all(sapply(fields, inherits, "structField"))) { @@ -76,7 +73,6 @@ structType.structField <- function(x, ...) { #' @rdname structType #' @method structType character -#' @export structType.character <- function(x, ...) { if (!is.character(x)) { stop("schema must be a DDL-formatted string.") @@ -119,7 +115,6 @@ print.structType <- function(x, ...) { #' @param ... additional argument(s) passed to the method. #' @return A structField object. #' @rdname structField -#' @export #' @examples #'\dontrun{ #' field1 <- structField("a", "integer") @@ -137,7 +132,6 @@ structField <- function(x, ...) { #' @rdname structField #' @method structField jobj -#' @export structField.jobj <- function(x, ...) { obj <- structure(list(), class = "structField") obj$jobj <- x @@ -212,7 +206,6 @@ checkType <- function(type) { #' @param type The data type of the field #' @param nullable A logical vector indicating whether or not the field is nullable #' @rdname structField -#' @export structField.character <- function(x, type, nullable = TRUE, ...) { if (class(x) != "character") { stop("Field name must be a string.") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 965471f3b07a0..a480ac606f10d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -35,7 +35,6 @@ connExists <- function(env) { #' Also terminates the backend this R session is connected to. #' @rdname sparkR.session.stop #' @name sparkR.session.stop -#' @export #' @note sparkR.session.stop since 2.0.0 sparkR.session.stop <- function() { env <- .sparkREnv @@ -84,7 +83,6 @@ sparkR.session.stop <- function() { #' @rdname sparkR.session.stop #' @name sparkR.stop -#' @export #' @note sparkR.stop since 1.4.0 sparkR.stop <- function() { sparkR.session.stop() @@ -103,7 +101,6 @@ sparkR.stop <- function() { #' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") @@ -270,7 +267,6 @@ sparkR.sparkContext <- function( #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRSQL.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -298,7 +294,6 @@ sparkRSQL.init <- function(jsc = NULL) { #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRHive.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -347,7 +342,6 @@ sparkRHive.init <- function(jsc = NULL) { #' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session #' @param ... named Spark properties passed to the method. -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -442,7 +436,6 @@ sparkR.session <- function( #' @return the SparkUI URL, or NA if it is disabled, or not started. #' @rdname sparkR.uiWebUrl #' @name sparkR.uiWebUrl -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index c8af798830b30..497f18c763048 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -37,7 +37,6 @@ setOldClass("jobj") #' @name crosstab #' @aliases crosstab,SparkDataFrame,character,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -63,7 +62,6 @@ setMethod("crosstab", #' @rdname cov #' @aliases cov,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -92,7 +90,6 @@ setMethod("cov", #' @name corr #' @aliases corr,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -124,7 +121,6 @@ setMethod("corr", #' @name freqItems #' @aliases freqItems,SparkDataFrame,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -168,7 +164,6 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @name approxQuantile #' @aliases approxQuantile,SparkDataFrame,character,numeric,numeric-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -205,7 +200,6 @@ setMethod("approxQuantile", #' @aliases sampleBy,SparkDataFrame,character,list,numeric-method #' @name sampleBy #' @family stat functions -#' @export #' @examples #'\dontrun{ #' df <- read.json("/path/to/file.json") diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index 8390bd5e6de72..fc83463f72cd4 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{read.stream} #' #' @param ssq A Java object reference to the backing Scala StreamingQuery -#' @export #' @note StreamingQuery since 2.2.0 #' @note experimental setClass("StreamingQuery", @@ -45,7 +44,6 @@ streamingQuery <- function(ssq) { } #' @rdname show -#' @export #' @note show(StreamingQuery) since 2.2.0 setMethod("show", "StreamingQuery", function(object) { @@ -70,7 +68,6 @@ setMethod("show", "StreamingQuery", #' @aliases queryName,StreamingQuery-method #' @family StreamingQuery methods #' @seealso \link{write.stream} -#' @export #' @examples #' \dontrun{ queryName(sq) } #' @note queryName(StreamingQuery) since 2.2.0 @@ -85,7 +82,6 @@ setMethod("queryName", #' @name explain #' @aliases explain,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ explain(sq) } #' @note explain(StreamingQuery) since 2.2.0 @@ -104,7 +100,6 @@ setMethod("explain", #' @name lastProgress #' @aliases lastProgress,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ lastProgress(sq) } #' @note lastProgress(StreamingQuery) since 2.2.0 @@ -129,7 +124,6 @@ setMethod("lastProgress", #' @name status #' @aliases status,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ status(sq) } #' @note status(StreamingQuery) since 2.2.0 @@ -150,7 +144,6 @@ setMethod("status", #' @name isActive #' @aliases isActive,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ isActive(sq) } #' @note isActive(StreamingQuery) since 2.2.0 @@ -177,7 +170,6 @@ setMethod("isActive", #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ awaitTermination(sq, 10000) } #' @note awaitTermination(StreamingQuery) since 2.2.0 @@ -202,7 +194,6 @@ setMethod("awaitTermination", #' @name stopQuery #' @aliases stopQuery,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ stopQuery(sq) } #' @note stopQuery(StreamingQuery) since 2.2.0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 164cd6d01a347..f1b5ecaa017df 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -108,7 +108,6 @@ isRDD <- function(name, env) { #' #' @param key the object to be hashed #' @return the hash code as an integer -#' @export #' @examples #'\dontrun{ #' hashCode(1L) # 1 diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R index 0799d841e5dc9..396b27bee80c6 100644 --- a/R/pkg/R/window.R +++ b/R/pkg/R/window.R @@ -29,7 +29,6 @@ #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") @@ -52,7 +51,6 @@ setMethod("windowPartitionBy", #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,Column-method -#' @export #' @note windowPartitionBy(Column) since 2.0.0 setMethod("windowPartitionBy", signature(col = "Column"), @@ -78,7 +76,6 @@ setMethod("windowPartitionBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- windowOrderBy("key1", "key2") @@ -101,7 +98,6 @@ setMethod("windowOrderBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,Column-method -#' @export #' @note windowOrderBy(Column) since 2.0.0 setMethod("windowOrderBy", signature(col = "Column"), From 98a5c0a35f0a24730f5074522939acf57ef95422 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 5 Mar 2018 10:50:00 -0800 Subject: [PATCH 120/593] [SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification ## What changes were proposed in this pull request? adding Structured Streaming tests for all Models/Transformers in spark.ml.classification ## How was this patch tested? N/A Author: WeichenXu Closes #20121 from WeichenXu123/ml_stream_test_classification. --- .../DecisionTreeClassifierSuite.scala | 29 +-- .../classification/GBTClassifierSuite.scala | 77 ++---- .../ml/classification/LinearSVCSuite.scala | 15 +- .../LogisticRegressionSuite.scala | 229 +++++++----------- .../MultilayerPerceptronClassifierSuite.scala | 44 ++-- .../ml/classification/NaiveBayesSuite.scala | 47 ++-- .../ml/classification/OneVsRestSuite.scala | 21 +- .../ProbabilisticClassifierSuite.scala | 29 +-- .../RandomForestClassifierSuite.scala | 16 +- 9 files changed, 202 insertions(+), 305 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 38b265d62611b..eeb0324187c5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs import testImplicits._ @@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite MLTestingUtils.checkCopyAndUids(dt, newTree) - val predictions = newTree.transform(newData) - .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => - assert(pred === rawPred.argmax, - s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") - val sum = rawPred.toArray.sum - assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, - "probability prediction mismatch") + testTransformer[(Vector, Double)](newData, newTree, + "prediction", "rawPrediction", "probability") { + case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, DecisionTreeClassificationModel](newTree, newData) + Vector, DecisionTreeClassificationModel](this, newTree, newData) } test("training with 1-category categorical feature") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 978f89c459f0a..092b4a01d5b0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -26,13 +26,12 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils @@ -40,8 +39,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import GBTClassifierSuite.compareAPIs @@ -126,14 +124,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // should predict all zeros binaryModel.setThresholds(Array(0.0, 1.0)) - val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } // should predict all ones binaryModel.setThresholds(Array(1.0, 0.0)) - val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 1.0 + } val gbtBase = new GBTClassifier val model = gbtBase.fit(df) @@ -141,15 +140,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // constant threshold scaling is the same as no thresholds binaryModel.setThresholds(Array(1.0, 1.0)) - val scaledPredictions = binaryModel.transform(df).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") { + scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) - val predictionsWithPredict = model.transform(df).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, model, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } } test("GBTClassifier: Predictor, Classifier methods") { @@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val blas = BLAS.getInstance() val validationDataset = validationData.toDF(labelCol, featuresCol) - val results = gbtModel.transform(validationDataset) - // check that raw prediction is tree predictions dot tree weights - results.select(rawPredictionCol, featuresCol).collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](validationDataset, gbtModel, + "rawPrediction", "features", "probability", "prediction") { + case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) => assert(raw.size === 2) + // check that raw prediction is tree predictions dot tree weights val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) - } - // Compare rawPrediction with probability - results.select(rawPredictionCol, probabilityCol).collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 2) + // Compare rawPrediction with probability assert(prob.size === 2) // Note: we should check other loss types for classification if they are added val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) assert(prob(0) ~== predFromRaw(0) relTol eps) assert(prob(1) ~== predFromRaw(1) relTol eps) assert(prob(0) + prob(1) ~== 1.0 absTol absEps) - } - // Compare prediction with probability - results.select(predictionCol, probabilityCol).collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") - val resultsUsingRaw2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) - val resultsUsingProb2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - gbtModel.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, GBTClassificationModel](gbtModel, validationDataset) + Vector, GBTClassificationModel](this, gbtModel, validationDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 41a5d22dd6283..a93825b8a812d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -21,20 +21,18 @@ import scala.util.Random import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.HingeAggregator import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.udf -class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LinearSVCSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau threshold: Double, expected: Set[(Int, Double)]): Unit = { model.setThreshold(threshold) - val results = model.transform(df).select("id", "prediction").collect() - .map(r => (r.getInt(0), r.getDouble(1))) - .toSet - assert(results === expected, s"Failed for threshold = $threshold") + testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") { + rows: Seq[Row] => + val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet + assert(results === expected, s"Failed for threshold = $threshold") + } } def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a5f81a38face9..9987cbf6ba116 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -22,22 +22,20 @@ import scala.language.existentials import scala.util.Random import scala.util.control.Breaks._ -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.LogisticAggregator import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType -class LogisticRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -332,15 +330,14 @@ class LogisticRegressionSuite val binaryModel = blr.fit(smallBinaryDataset) binaryModel.setThreshold(1.0) - val binaryZeroPredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } binaryModel.setThreshold(0.0) - val binaryOnePredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(smallMultinomialDataset) @@ -348,31 +345,36 @@ class LogisticRegressionSuite // should predict all zeros model.setThresholds(Array(1, 1000, 1000)) - val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } // should predict all ones model.setThresholds(Array(1000, 1, 1000)) - val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(onePredictions.forall(_.getDouble(0) === 1.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } // should predict all twos model.setThresholds(Array(1000, 1000, 1)) - val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(twoPredictions.forall(_.getDouble(0) === 2.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 2.0) + } // constant threshold scaling is the same as no thresholds model.setThresholds(Array(1000, 1000, 1000)) - val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1)) - val predictionsWithPredict = - model.transform(smallMultinomialDataset).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -403,21 +405,19 @@ class LogisticRegressionSuite // Modify model params, and check that the params worked. model.setThreshold(1.0) - val predAllZero = model.transform(smallBinaryDataset) - .select("prediction", "myProbability") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predAllZero.forall(_ === 0), - s"With threshold=1.0, expected predictions to be all 0, but only" + - s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model, "prediction", "myProbability") { rows => + val predAllZero = rows.map(_.getDouble(0)) + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + } // Call transform with params, and check that the params worked. - val predNotAllZero = - model.transform(smallBinaryDataset, model.threshold -> 0.0, - model.probabilityCol -> "myProb") - .select("prediction", "myProb") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predNotAllZero.exists(_ !== 0.0)) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model.copy(ParamMap(model.threshold -> 0.0, + model.probabilityCol -> "myProb")), "prediction", "myProb") { + rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0)) + } // Call fit() with new params, and check as many params as we can. lr.setThresholds(Array(0.6, 0.4)) @@ -441,10 +441,10 @@ class LogisticRegressionSuite val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallMultinomialDataset) - // check that raw prediction is coefficients dot features + intercept - results.select("rawPrediction", "features").collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), + model, "rawPrediction", "features", "probability") { + case Row(raw: Vector, features: Vector, prob: Vector) => + // check that raw prediction is coefficients dot features + intercept assert(raw.size === 3) val margins = Array.tabulate(3) { k => var margin = 0.0 @@ -455,12 +455,7 @@ class LogisticRegressionSuite margin } assert(raw ~== Vectors.dense(margins) relTol eps) - } - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 3) + // Compare rawPrediction with probability assert(prob.size === 3) val max = raw.toArray.max val subtract = if (max > 0) max else 0.0 @@ -472,39 +467,8 @@ class LogisticRegressionSuite assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps) } - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => - val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 - assert(pred == predFromProb) - } - - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallMultinomialDataset) + Vector, LogisticRegressionModel](this, model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -517,51 +481,22 @@ class LogisticRegressionSuite val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallBinaryDataset) - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), + model, "rawPrediction", "probability", "prediction") { + case Row(raw: Vector, prob: Vector, pred: Double) => + // Compare rawPrediction with probability assert(raw.size === 2) assert(prob.size === 2) val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) assert(prob(1) ~== probFromRaw1 relTol eps) assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) - } - - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallBinaryDataset) + Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } test("coefficients and intercept methods") { @@ -616,19 +551,21 @@ class LogisticRegressionSuite LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) ).toDF() - val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() - - // probabilities are correct when margins have to be adjusted - val raw1 = results(0).getAs[Vector](0) - val prob1 = results(0).getAs[Vector](1) - assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) - assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) - - // probabilities are correct when margins don't have to be adjusted - val raw2 = results(1).getAs[Vector](0) - val prob2 = results(1).getAs[Vector](1) - assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) - assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + + testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(), + model, "rawPrediction", "probability") { results: Seq[Row] => + // probabilities are correct when margins have to be adjusted + val raw1 = results(0).getAs[Vector](0) + val prob1 = results(0).getAs[Vector](1) + assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) + assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) + + // probabilities are correct when margins don't have to be adjusted + val raw2 = results(1).getAs[Vector](0) + val prob2 = results(1).getAs[Vector](1) + assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) + assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + } } test("MultiClassSummarizer") { @@ -2567,10 +2504,13 @@ class LogisticRegressionSuite val model1 = lr.fit(smallBinaryDataset) val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial") val model2 = lr2.fit(smallBinaryDataset) - val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect() - val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect() - predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect() + .map(_.getDouble(0)) + for (model <- Seq(model1, model2)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === binaryExpected + } } assert(model2.summary.totalIterations === 1) @@ -2579,10 +2519,13 @@ class LogisticRegressionSuite val lr4 = new LogisticRegression() .setInitialModel(model3).setMaxIter(5).setFamily("multinomial") val model4 = lr4.fit(smallMultinomialDataset) - val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect() - val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect() - predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction") + .collect().map(_.getDouble(0)) + for (model <- Seq(model3, model4)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === multinomialExpected + } } assert(model4.summary.totalIterations === 1) } @@ -2638,8 +2581,8 @@ class LogisticRegressionSuite LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(constantData) - val results = model.transform(constantData) - results.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantData, model, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0))) @@ -2653,8 +2596,8 @@ class LogisticRegressionSuite LabeledPoint(0.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) - val resultsZero = modelZeroLabel.transform(constantZeroData) - resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) @@ -2666,8 +2609,8 @@ class LogisticRegressionSuite val constantDataWithMetadata = constantData .select(constantData("label").as("label", labelMeta), constantData("features")) val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata) - val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata) - resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index d3141ec708560..daa58a56896d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions._ -class MultilayerPerceptronClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -75,11 +70,9 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - val result = model.transform(dataset) MLTestingUtils.checkCopyAndUids(trainer, model) - val predictionAndLabels = result.select("prediction", "label").collect() - predictionAndLabels.foreach { case Row(p: Double, l: Double) => - assert(p == l) + testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") { + case Row(p: Double, l: Double) => assert(p == l) } } @@ -99,13 +92,12 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(strongDataset) - val result = model.transform(strongDataset) - result.select("probability", "expectedProbability").collect().foreach { - case Row(p: Vector, e: Vector) => - assert(p ~== e absTol 1e-3) + testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model, + "probability", "expectedProbability") { + case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, MultilayerPerceptronClassificationModel](model, strongDataset) + Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset) } test("test model probability") { @@ -118,11 +110,10 @@ class MultilayerPerceptronClassifierSuite .setSolver("l-bfgs") val model = trainer.fit(dataset) model.setProbabilityCol("probability") - val result = model.transform(dataset) - val features2prob = udf { features: Vector => model.mlpModel.predict(features) } - result.select(features2prob(col("features")), col("probability")).collect().foreach { - case Row(p1: Vector, p2: Vector) => - assert(p1 ~== p2 absTol 1e-3) + testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") { + case Row(features: Vector, prob: Vector) => + val prob2 = model.mlpModel.predict(features) + assert(prob ~== prob2 absTol 1e-3) } } @@ -175,9 +166,6 @@ class MultilayerPerceptronClassifierSuite val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { - case Row(p: Double, l: Double) => (p, l) - } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) @@ -189,8 +177,12 @@ class MultilayerPerceptronClassifierSuite lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) - val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) - assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") { + rows: Seq[Row] => + val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1))) + val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels)) + assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + } } test("read/write: MultilayerPerceptronClassifier") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 0d3adf993383f..49115c8a4db30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF() } - def validatePrediction(predictionAndLabels: DataFrame): Unit = { - val numOfErrorPredictions = predictionAndLabels.collect().count { + def validatePrediction(predictionAndLabels: Seq[Row]): Unit = { + val numOfErrorPredictions = predictionAndLabels.filter { case Row(prediction: Double, label: Double) => prediction != label - } + }.length // At least 80% of the predictions should be on. - assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + assert(numOfErrorPredictions < predictionAndLabels.length / 5) } def validateModelFit( @@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } def validateProbabilities( - featureAndProbabilities: DataFrame, + featureAndProbabilities: Seq[Row], model: NaiveBayesModel, modelType: String): Unit = { - featureAndProbabilities.collect().foreach { + featureAndProbabilities.foreach { case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { @@ -154,15 +153,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "multinomial") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "multinomial") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("Naive Bayes with weighted samples") { @@ -210,15 +212,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "bernoulli") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "bernoulli") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("detect negative values") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 25bad59b9c9cf..11e88367108b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneVsRestSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -85,10 +83,6 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset.select("prediction", "label").rdd.map { - row => (row.getDouble(0), row.getDouble(1)) - } - val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) @@ -97,8 +91,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // determine the #confusion matrix in each class. // bound how much error we allow compared to multinomial logistic regression. val expectedMetrics = new MulticlassMetrics(results) - val ovaMetrics = new MulticlassMetrics(ovaResults) - assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel, + "prediction", "label") { rows => + val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) } + val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults)) + assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + } } test("one-vs-rest: tuning parallelism does not change output") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index d649ceac949c4..1c8c9829f18d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} @@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite { def testPredictMethods[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]( - model: M, testData: Dataset[_]): Unit = { + mlTest: MLTest, model: M, testData: Dataset[_]): Unit = { val allColModel = model.copy(ParamMap.empty) .setRawPredictionCol("rawPredictionAll") .setProbabilityCol("probabilityAll") .setPredictionCol("predictionAll") - val allColResult = allColModel.transform(testData) + + val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol)) + .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll") for (rawPredictionCol <- Seq("", "rawPredictionSingle")) { for (probabilityCol <- Seq("", "probabilitySingle")) { @@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite { .setProbabilityCol(probabilityCol) .setPredictionCol(predictionCol) - val result = newModel.transform(allColResult) - - import org.apache.spark.sql.functions._ - - val resultRawPredictionCol = - if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol) - val resultProbabilityCol = - if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol) - val resultPredictionCol = - if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol) + import allColResult.sparkSession.implicits._ - result.select( - resultRawPredictionCol, col("rawPredictionAll"), - resultProbabilityCol, col("probabilityAll"), - resultPredictionCol, col("predictionAll") - ).collect().foreach { + mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel, + if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol, + "rawPredictionAll", + if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll", + if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll" + ) { case Row( rawPredictionSingle: Vector, rawPredictionAll: Vector, probabilitySingle: Vector, probabilityAll: Vector, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2cca2e6c04698..02a9d5c2a18c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -23,11 +23,10 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs import testImplicits._ @@ -143,11 +141,8 @@ class RandomForestClassifierSuite MLTestingUtils.checkCopyAndUids(rf, model) - val predictions = model.transform(df) - .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", + "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") val sum = rawPred.toArray.sum @@ -155,8 +150,9 @@ class RandomForestClassifierSuite "probability prediction mismatch") assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) } + ProbabilisticClassifierSuite.testPredictMethods[ - Vector, RandomForestClassificationModel](model, df) + Vector, RandomForestClassificationModel](this, model, df) } test("Fitting without numClasses in metadata") { From ba622f45caa808a9320c1f7ba4a4f344365dcf90 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 5 Mar 2018 20:43:03 +0100 Subject: [PATCH 121/593] [SPARK-23585][SQL] Add interpreted execution to UnwrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to UnwrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20736 from mgaido91/SPARK-23586. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++++-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 80618af1e859f..03cc8eaceb4e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -382,8 +382,14 @@ case class UnwrapOption( override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val inputObject = child.eval(input) + if (inputObject == null) { + null + } else { + inputObject.asInstanceOf[Option[_]].orNull + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 3edcc02f15264..d95db5867b19c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -66,4 +66,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvalutionWithUnsafeProjection( mapEncoder.serializer.head, mapExpected, mapInputRow) } + + test("SPARK-23585: UnwrapOption should support interpreted execution") { + val cls = classOf[Option[Int]] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val unwrapObject = UnwrapOption(IntegerType, inputObject) + Seq((Some(1), 1), (None, null), (null, null)).foreach { case (input, expected) => + checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From b0f422c3861a5a3831e481b8ffac08f6fa085d00 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 5 Mar 2018 13:23:01 -0800 Subject: [PATCH 122/593] [SPARK-23559][SS] Add epoch ID to DataWriterFactory. ## What changes were proposed in this pull request? Add an epoch ID argument to DataWriterFactory for use in streaming. As a side effect of passing in this value, DataWriter will now have a consistent lifecycle; commit() or abort() ends the lifecycle of a DataWriter instance in any execution mode. I considered making a separate streaming interface and adding the epoch ID only to that one, but I think it requires a lot of extra work for no real gain. I think it makes sense to define epoch 0 as the one and only epoch of a non-streaming query. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #20710 from jose-torres/api2. --- .../sql/kafka010/KafkaStreamWriter.scala | 5 +++- .../sql/sources/v2/writer/DataWriter.java | 12 ++++++--- .../sources/v2/writer/DataWriterFactory.java | 5 +++- .../v2/writer/streaming/StreamWriter.java | 19 +++++++------- .../datasources/v2/WriteToDataSourceV2.scala | 25 +++++++++++++------ .../streaming/MicroBatchExecution.scala | 7 ++++++ .../sources/PackedRowWriterFactory.scala | 5 +++- .../streaming/sources/memoryV2.scala | 5 +++- .../sources/v2/SimpleWritableDataSource.scala | 10 ++++++-- 9 files changed, 65 insertions(+), 28 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9307bfc001c03..ae5b5c52d514e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -65,7 +65,10 @@ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 53941a89ba94e..39bf458298862 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -31,13 +31,17 @@ * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will * not be processed. If all records are successfully written, {@link #commit()} is called. * + * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle + * is over and Spark will not use it again. + * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an - * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. + * exception will be sent to the driver side, and Spark may retry this writing task a few times. + * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a + * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index ea95442511ce5..c2c2ab73257e8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -48,6 +48,9 @@ public interface DataWriterFactory extends Serializable { * same task id but different attempt number, which means there are multiple * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. For non-streaming queries, + * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 4913341bd505d..a316b2a4c1d82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -23,11 +23,10 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and - * aborts relative to an epoch ID determined by the execution engine. + * A {@link DataSourceWriter} for use with structured streaming. * - * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, - * and so must reset any internal state after a successful commit. + * Streaming queries are divided into intervals of data called epochs, with a monotonically + * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving public interface StreamWriter extends DataSourceWriter { @@ -39,21 +38,21 @@ public interface StreamWriter extends DataSourceWriter { * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. * - * To support exactly-once processing, writer implementations should ensure that this method is - * idempotent. The execution engine may call commit() multiple times for the same epoch - * in some circumstances. + * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * To support exactly-once data semantics, implementations must ensure that multiple commits for + * the same epoch are idempotent. */ void commit(long epochId, WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or + * Aborts this writing job because some data writers are failed and keep failing when retried, or * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. * - * Unless the abort is triggered by the failure of commit, the given messages should have some - * null slots as there maybe only a few data writers that are committed before the abort + * Unless the abort is triggered by the failure of commit, the given messages will have some + * null slots, as there may be only a few data writers that were committed before the abort * happens, or some data writers were committed but their commit messages haven't reached the * driver when the abort is triggered. So this is just a "best effort" for data sources to * clean up the data left by data writers. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 41cdfc80d8a19..e80b44c1cdc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -132,7 +132,8 @@ object DataWritingSparkTask extends Logging { val stageId = context.stageId() val partId = context.partitionId() val attemptId = context.attemptNumber() - val dataWriter = writeTask.createDataWriter(partId, attemptId) + val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") + val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -172,7 +173,6 @@ object DataWritingSparkTask extends Logging { writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) @@ -180,10 +180,15 @@ object DataWritingSparkTask extends Logging { var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong do { + var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { - iter.foreach(dataWriter.write) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } logInfo(s"Writer for partition ${context.partitionId()} is committing.") val msg = dataWriter.commit() logInfo(s"Writer for partition ${context.partitionId()} committed.") @@ -196,9 +201,10 @@ object DataWritingSparkTask extends Logging { // Continuous shutdown always involves an interrupt. Just finish the task. } })(catchBlock = { - // If there is an error, abort this writer + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. logError(s"Writer for partition ${context.partitionId()} is aborting.") - dataWriter.abort() + if (dataWriter != null) dataWriter.abort() logError(s"Writer for partition ${context.partitionId()} aborted.") }) } while (!context.isInterrupted()) @@ -211,9 +217,12 @@ class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber), + rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6bd03972c301d..ff4be9c7ab874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -469,6 +469,9 @@ class MicroBatchExecution( case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } + sparkSessionToRunBatch.sparkContext.setLocalProperty( + MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -518,3 +521,7 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object MicroBatchExecution { + val BATCH_ID_KEY = "streaming.sql.batchId" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 248295e401a0d..e07355aa37dba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat * for production-quality sinks. It's intended for use in tests. */ case object PackedRowWriterFactory extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index f960208155e3b..5f58246083bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -147,7 +147,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 36dd2a350a055..a5007fa321359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -207,7 +207,10 @@ private[v2] object SimpleCounter { class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -240,7 +243,10 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) From f2cab56ca22ed5db5ff604cd78cdb55aaa58f651 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 5 Mar 2018 14:57:32 -0800 Subject: [PATCH 123/593] [SPARK-23040][CORE] Returns interruptible iterator for shuffle reader ## What changes were proposed in this pull request? Before this commit, a non-interruptible iterator is returned if aggregator or ordering is specified. This commit also ensures that sorter is closed even when task is cancelled(killed) in the middle of sorting. ## How was this patch tested? Add a unit test in JobCancellationSuite Author: Xianjin YE Closes #20449 from advancedxy/SPARK-23040. --- .../shuffle/BlockStoreShuffleReader.scala | 9 ++- .../apache/spark/JobCancellationSuite.scala | 65 ++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index edd69715c9602..85e7e56a04a7d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -94,7 +94,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Sort the output if there is a sort ordering defined. - dep.keyOrdering match { + val resultIter = dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. val sorter = @@ -103,9 +103,16 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener(_ => { + sorter.stop() + }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } + // Use another interruptible iterator here to support task cancellation as aggregator or(and) + // sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 8a77aea75a992..3b793bb231cf3 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -26,7 +27,7 @@ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft override def afterEach() { try { resetSparkContext() + JobCancellationSuite.taskStartedSemaphore.drainPermits() + JobCancellationSuite.taskCancelledSemaphore.drainPermits() + JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits() + JobCancellationSuite.executionOfInterruptibleCounter.set(0) } finally { super.afterEach() } @@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft f2.get() } + test("interruptible iterator of shuffle reader") { + // In this test case, we create a Spark job of two stages. The second stage is cancelled during + // execution and a counter is used to make sure that the corresponding tasks are indeed + // cancelled. + import JobCancellationSuite._ + sc = new SparkContext("local[2]", "test interruptible iterator") + + val taskCompletedSem = new Semaphore(0) + + sc.addSparkListener(new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + // release taskCancelledSemaphore when cancelTasks event has been posted + if (stageCompleted.stageInfo.stageId == 1) { + taskCancelledSemaphore.release(1000) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.stageId == 1) { // make sure tasks are completed + taskCompletedSem.release() + } + } + }) + + val f = sc.parallelize(1 to 1000).map { i => (i, i) } + .repartitionAndSortWithinPartitions(new HashPartitioner(1)) + .mapPartitions { iter => + taskStartedSemaphore.release() + iter + }.foreachAsync { x => + if (x._1 >= 10) { + // This block of code is partially executed. It will be blocked when x._1 >= 10 and the + // next iteration will be cancelled if the source iterator is interruptible. Then in this + // case, the maximum num of increment would be 10(|1...10|) + taskCancelledSemaphore.acquire() + } + executionOfInterruptibleCounter.getAndIncrement() + } + + taskStartedSemaphore.acquire() + // Job is cancelled when: + // 1. task in reduce stage has been started, guaranteed by previous line. + // 2. task in reduce stage is blocked after processing at most 10 records as + // taskCancelledSemaphore is not released until cancelTasks event is posted + // After job being cancelled, task in reduce stage will be cancelled and no more iteration are + // executed. + f.cancel() + + val e = intercept[SparkException](f.get()).getCause + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + + // Make sure tasks are indeed completed. + taskCompletedSem.acquire() + assert(executionOfInterruptibleCounter.get() <= 10) + } + def testCount() { // Cancel before launching any tasks { @@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft object JobCancellationSuite { + // To avoid any headaches, reset these global variables in the companion class's afterEach block val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) val twoJobsSharingStageSemaphore = new Semaphore(0) + val executionOfInterruptibleCounter = new AtomicInteger(0) } From 508573958dc9b6402e684cd6dd37202deaaa97f6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 5 Mar 2018 15:03:27 -0800 Subject: [PATCH 124/593] [SPARK-23538][CORE] Remove custom configuration for SSL client. These options were used to configure the built-in JRE SSL libraries when downloading files from HTTPS servers. But because they were also used to set up the now (long) removed internal HTTPS file server, their default configuration chose convenience over security by having overly lenient settings. This change removes the configuration options that affect the JRE SSL libraries. The JRE trust store can still be configured via system properties (or globally in the JRE security config). The only lost functionality is not being able to disable the default hostname verifier when using spark-submit, which should be fine since Spark itself is not using https for any internal functionality anymore. I also removed the HTTP-related code from the REPL class loader, since we haven't had a HTTP server for REPL-generated classes for a while. Author: Marcelo Vanzin Closes #20723 from vanzin/SPARK-23538. --- .../org/apache/spark/SecurityManager.scala | 45 ------------ .../scala/org/apache/spark/util/Utils.scala | 15 ---- .../org/apache/spark/SSLSampleConfigs.scala | 68 ------------------- .../apache/spark/SecurityManagerSuite.scala | 45 ------------ docs/security.md | 4 -- .../spark/repl/ExecutorClassLoader.scala | 53 ++------------- 6 files changed, 7 insertions(+), 223 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2519d266879aa..da1c89cd78901 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -256,51 +256,6 @@ private[spark] class SecurityManager( // the default SSL configuration - it will be used by all communication layers unless overwritten private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) - // SSL configuration for the file server. This is used by Utils.setupSecureURLConnection(). - val fileServerSSLOptions = getSSLOptions("fs") - val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { - val trustStoreManagers = - for (trustStore <- fileServerSSLOptions.trustStore) yield { - val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() - - try { - val ks = KeyStore.getInstance(KeyStore.getDefaultType) - ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) - - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(ks) - tmf.getTrustManagers - } finally { - input.close() - } - } - - lazy val credulousTrustStoreManagers = Array({ - logWarning("Using 'accept-all' trust manager for SSL connections.") - new X509TrustManager { - override def getAcceptedIssuers: Array[X509Certificate] = null - - override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} - - override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} - }: TrustManager - }) - - require(fileServerSSLOptions.protocol.isDefined, - "spark.ssl.protocol is required when enabling SSL connections.") - - val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.get) - sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) - - val hostVerifier = new HostnameVerifier { - override def verify(s: String, sslSession: SSLSession): Boolean = true - } - - (Some(sslContext.getSocketFactory), Some(hostVerifier)) - } else { - (None, None) - } - def getSSLOptions(module: String): SSLOptions = { val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d493663f0b168..2e2a4a259e9af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -673,7 +673,6 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } - Utils.setupSecureURLConnection(uc, securityMgr) val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 @@ -2363,20 +2362,6 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } - /** - * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and - * the host verifier from the given security manager. - */ - def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { - urlConnection match { - case https: HttpsURLConnection => - sm.sslSocketFactory.foreach(https.setSSLSocketFactory) - sm.hostnameVerifier.foreach(https.setHostnameVerifier) - https - case connection => connection - } - } - def invoke( clazz: Class[_], obj: AnyRef, diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala deleted file mode 100644 index 33270bec6247c..0000000000000 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -object SSLSampleConfigs { - val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File( - this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath - val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - - val enabledAlgorithms = - // A reasonable set of TLSv1.2 Oracle security provider suites - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "TLS_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + - // and their equivalent names in the IBM Security provider - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "SSL_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" - - def sparkSSLConfig(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", keyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - - def sparkSSLConfigUntrusted(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", untrustedKeyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - -} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 106ece7aed0a4..e357299770a2e 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -370,51 +370,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions("user1") === false) } - test("ssl on setup") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val expectedAlgorithms = Set( - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "TLS_RSA_WITH_AES_256_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "SSL_RSA_WITH_AES_256_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === true) - - assert(securityManager.sslSocketFactory.isDefined === true) - assert(securityManager.hostnameVerifier.isDefined === true) - - assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - } - - test("ssl off setup") { - val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) - - System.setProperty("spark.ssl.configFile", file.getAbsolutePath) - val conf = new SparkConf() - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.sslSocketFactory.isDefined === false) - assert(securityManager.hostnameVerifier.isDefined === false) - } - test("missing secret authentication key") { val conf = new SparkConf().set("spark.authenticate", "true") val mgr = new SecurityManager(conf) diff --git a/docs/security.md b/docs/security.md index 0f384b411812a..913d9df50eb1c 100644 --- a/docs/security.md +++ b/docs/security.md @@ -44,10 +44,6 @@ component-specific configuration namespaces used to override the default setting Config Namespace Component - - spark.ssl.fs - File download client (used to download jars and files from HTTPS-enabled servers). - spark.ssl.ui Spark application Web UI diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 127f67329f266..4dc399827ffed 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,12 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException} -import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream} +import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels -import scala.util.control.NonFatal - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ @@ -30,13 +28,13 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.util.ParentClassLoader /** - * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, used to load classes - * defined by the interpreter when the REPL is used. Allows the user to specify if user class path - * should be first. This class loader delegates getting/finding resources to parent loader, which - * makes sense until REPL never provide resource dynamically. + * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load + * classes defined by the interpreter when the REPL is used. Allows the user to specify if user + * class path should be first. This class loader delegates getting/finding resources to parent + * loader, which makes sense until REPL never provide resource dynamically. * * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or * the load failed, that it will call the overridden `findClass` function. To avoid the potential @@ -60,7 +58,6 @@ class ExecutorClassLoader( private val fetchFn: (String) => InputStream = uri.getScheme() match { case "spark" => getClassFileInputStreamFromSparkRPC - case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer case _ => val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) getClassFileInputStreamFromFileSystem(fileSystem) @@ -113,42 +110,6 @@ class ExecutorClassLoader( } } - private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), - SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] - // Set the connection timeouts (for testing purposes) - if (httpUrlConnectionTimeoutMillis != -1) { - connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) - connection.setReadTimeout(httpUrlConnectionTimeoutMillis) - } - connection.connect() - try { - if (connection.getResponseCode != 200) { - // Close the error stream so that the connection is eligible for re-use - try { - connection.getErrorStream.close() - } catch { - case ioe: IOException => - logError("Exception while closing error stream", ioe) - } - throw new ClassNotFoundException(s"Class file not found at URL $url") - } else { - connection.getInputStream - } - } catch { - case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => - connection.disconnect() - throw e - } - } - private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) From 7706eea6a8bdcd73e9dde5212368f8825e2f1801 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 5 Mar 2018 15:53:10 -0800 Subject: [PATCH 125/593] [SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add tests The `__del__` method that explicitly detaches the object was moved from `JavaParams` to `JavaWrapper` class, this way model summaries could also be garbage collected in Java. A test case was added to make sure that relevant error messages are thrown after the objects are deleted. I ran pyspark tests agains `pyspark-ml` module `./python/run-tests --python-executables=$(which python) --modules=pyspark-ml` Author: Yogesh Garg Closes #20724 from yogeshg/java_wrapper_memory. --- python/pyspark/ml/tests.py | 39 ++++++++++++++++++++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 ++++---- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 116885969345c..6dee6938d8916 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fbc5b5ef..5061f6434794a 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ def __init__(self, java_obj=None): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair. From f6b49f9d1b6f218408197f7272c1999fe3d94328 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 01:37:51 +0100 Subject: [PATCH 126/593] [SPARK-23586][SQL] Add interpreted execution to WrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to WrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20741 from mgaido91/SPARK-23586_2. --- .../sql/catalyst/expressions/objects/objects.scala | 3 +-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 03cc8eaceb4e6..d832fe0a6857c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -422,8 +422,7 @@ case class WrapOption(child: Expression, optType: DataType) override def inputTypes: Seq[AbstractDataType] = optType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = Option(child.eval(input)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d95db5867b19c..d535578a7eb06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -75,4 +75,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23586: WrapOption should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val wrapObject = WrapOption(inputObject, cls) + Seq((1, Some(1)), (null, None)).foreach { case (input, expected) => + checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From 8c5b34c425bda2079a1ff969b12c067f2bb3f18f Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 5 Mar 2018 16:49:24 -0800 Subject: [PATCH 127/593] =?UTF-8?q?[SPARK-23604][SQL]=20Change=20Statistic?= =?UTF-8?q?s.isEmpty=20to=20!Statistics.hasNonNul=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …lValue ## What changes were proposed in this pull request? Parquet 1.9 will change the semantics of Statistics.isEmpty slightly to reflect if the null value count has been set. That breaks a timestamp interoperability test that cares only about whether there are column values present in the statistics of a written file for an INT96 column. Fix by using Statistics.hasNonNullValue instead. ## How was this patch tested? Unit tests continue to pass against Parquet 1.8, and also pass against a Parquet build including PARQUET-1217. Author: Henry Robinson Closes #20740 from henryr/spark-23604. --- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index fbd83a0fa425a..9c75965639d8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -184,7 +184,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // when the data is read back as mentioned above, b/c int96 is unsigned. This // assert makes sure this holds even if we change parquet versions (if eg. there // were ever statistics even on unsigned columns). - assert(columnStats.isEmpty) + assert(!columnStats.hasNonNullValue) } // These queries should return the entire dataset with the conversion applied, From ad640a5affceaaf3979e25848628fb1dfcdf932a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 5 Mar 2018 20:35:14 -0800 Subject: [PATCH 128/593] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The proposed explain format: **[streaming header] [RelationV2/ScanV2] [data source name] [output] [pushed filters] [options]** **streaming header**: if it's a streaming relation, put a "Streaming" at the beginning. **RelationV2/ScanV2**: if it's a logical plan, put a "RelationV2", else, put a "ScanV2" **data source name**: the simple class name of the data source implementation **output**: a string of the plan output attributes **pushed filters**: a string of all the filters that have been pushed to this data source **options**: all the options to create the data source reader. The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == RelationV2 AdvancedDataSourceV2[j#1] == Physical Plan == *(1) ScanV2 AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) ScanV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) ScanV2 MemoryStreamDataSource[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20647 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 2 +- .../sql/kafka010/KafkaContinuousTest.scala | 2 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../datasources/v2/DataSourceV2Relation.scala | 34 +++++-- .../datasources/v2/DataSourceV2ScanExec.scala | 18 +++- .../datasources/v2/DataSourceV2Strategy.scala | 8 +- .../v2/DataSourceV2StringFormat.scala | 94 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 29 +++++- .../continuous/ContinuousExecution.scala | 8 +- .../spark/sql/streaming/StreamSuite.scala | 12 ++- .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 11 +-- 13 files changed, 183 insertions(+), 105 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index f679e9bfc0450..aab8ec42189fb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -60,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 48ac3fc1e8f9d..fa1468a3943c8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index f2b3ff7615e74..e017fd9b84d21 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -124,7 +124,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cc6cb631e3f06..2b282ffae2390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,15 +35,12 @@ case class DataSourceV2Relation( options: Map[String, String], projection: Seq[AttributeReference], filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = { - s"DataSourceV2Relation(source=${source.name}, " + - s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + - s"filters=[${pushedFilters.mkString(", ")}], options=$options)" - } + override def simpleString: String = "RelationV2 " + metadataString override lazy val schema: StructType = reader.readSchema() @@ -107,19 +104,36 @@ case class DataSourceV2Relation( } /** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. + * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true. + * + * Note that, this plan has a mutable reader, so Spark won't apply operator push-down for this plan, + * to avoid making the plan mutable. We should consolidate this plan and [[DataSourceV2Relation]] + * after we figure out how to apply operator push-down for streaming data sources. */ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], + source: DataSourceV2, + options: Map[String, String], reader: DataSourceReader) - extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + override def isStreaming: Boolean = true - override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + override def simpleString: String = "Streaming RelationV2 " + metadataString override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: StreamingDataSourceV2Relation => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..cb691ba297076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,10 +37,23 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, + @transient options: Map[String, String], @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = "ScanV2 " + metadataString + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2ScanExec => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c4e7644683c36..1ac9572de6412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case relation: DataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil - case relation: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala new file mode 100644 index 0000000000000..aed55a429bfd7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A trait that can be used by data source v2 related query plans(both logical and physical), to + * provide a string format of the data source information for explain. + */ +trait DataSourceV2StringFormat { + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The options for this data source reader. + */ + def options: Map[String, String] + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + private def sourceName: String = source match { + case registered: DataSourceRegister => registered.shortName() + case _ => source.getClass.getSimpleName.stripSuffix("$") + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + + if (filters.nonEmpty) { + entries += "Filters" -> filters.mkString("[", ", ", "]") + } + + // TODO: we should only display some standard options like path, table, etc. + if (options.nonEmpty) { + entries += "Options" -> Utils.redact(options).map { + case (k, v) => s"$k=$v" + }.mkString("[", ",", "]") + } + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) + }, " (", ", ", ")") + } else { + "" + } + + s"$sourceName$outputStr$entriesStr" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ff4be9c7ab874..6e231970f4a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} +import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,9 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -97,6 +100,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = dataSourceV2 -> options logInfo(s"Using MicroBatchReader [$reader] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) @@ -419,8 +423,19 @@ class MicroBatchExecution( toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + + val (source, options) = reader match { + // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` + // implementation. We provide a fake one here for explain. + case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + case _ => readerToDataSourceMap.getOrElse(reader, { + FakeDataSourceV2 -> Map.empty[String, String] + }) + } + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } @@ -525,3 +540,7 @@ class MicroBatchExecution( object MicroBatchExecution { val BATCH_ID_KEY = "streaming.sql.batchId" } + +object MemoryStreamDataSource extends DataSourceV2 + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index daebd1dd010ac..1758b3844bd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(source, options, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..c1ec1eba69fb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,20 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 08f722ecb10e5..e44aef09f1f3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -629,8 +629,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r + case r: StreamingExecutionRelation => r.source + case r: StreamingDataSourceV2Relation => r.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..f5884b9c8de12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From e8a259d66dda0d4c76f3af8933676bade8a7451d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 6 Mar 2018 13:55:13 +0100 Subject: [PATCH 129/593] [SPARK-23594][SQL] GetExternalRowField should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `GetExternalRowField`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20746 from maropu/SPARK-23594. --- .../expressions/objects/objects.scala | 14 ++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d832fe0a6857c..97e3ff88858d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1358,11 +1358,19 @@ case class GetExternalRowField( override def dataType: DataType = ObjectType(classOf[Object]) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + override def eval(input: InternalRow): Any = { + val inputRow = child.eval(input).asInstanceOf[Row] + if (inputRow == null) { + throw new RuntimeException("The input external row cannot be null.") + } + if (inputRow.isNullAt(index)) { + throw new RuntimeException(errMsg) + } + inputRow.get(index) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d535578a7eb06..0f376c4b63c15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ @@ -84,4 +85,23 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23594 GetExternalRowField should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") + Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => + checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + } + + // If an input row or a field are null, a runtime exception will be thrown + val errMsg1 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(null))) + }.getMessage + assert(errMsg1 === "The input external row cannot be null.") + + val errMsg2 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) + }.getMessage + assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + } } From 8bceb899dc3220998a4ea4021f3b477f78faaca8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 6 Mar 2018 08:52:28 -0600 Subject: [PATCH 130/593] [SPARK-23601][BUILD] Remove .md5 files from release ## What changes were proposed in this pull request? Remove .md5 files from release artifacts ## How was this patch tested? N/A Author: Sean Owen Closes #20737 from srowen/SPARK-23601. --- dev/create-release/release-build.sh | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index a3579f21fc539..c00b00b845401 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -164,8 +164,6 @@ if [[ "$1" == "package" ]]; then tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ --detach-sig spark-$SPARK_VERSION.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ - spark-$SPARK_VERSION.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION @@ -215,9 +213,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $R_DIST_NAME.asc \ --detach-sig $R_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $R_DIST_NAME > \ - $R_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ $R_DIST_NAME.sha512 @@ -234,9 +229,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $PYTHON_DIST_NAME.asc \ --detach-sig $PYTHON_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $PYTHON_DIST_NAME > \ - $PYTHON_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $PYTHON_DIST_NAME > \ $PYTHON_DIST_NAME.sha512 @@ -247,9 +239,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ - spark-$SPARK_VERSION-bin-$NAME.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 @@ -382,18 +371,11 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there + # this must have .asc and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 4c587eb4887623c839854c1505f495de42898229 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 17:42:17 +0100 Subject: [PATCH 131/593] [SPARK-23590][SQL] Add interpreted execution to CreateExternalRow ## What changes were proposed in this pull request? The PR adds interpreted execution to CreateExternalRow ## How was this patch tested? added UT Author: Marco Gaido Closes #20749 from mgaido91/SPARK-23590. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 6 ++++-- .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 4 +++- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 97e3ff88858d0..721d589709131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1111,8 +1111,10 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val values = children.map(_.eval(input)).toArray + new GenericRowWithSchema(values, schema) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4c8eab19c5cc..29f0cc0d991aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,6 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -60,7 +61,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte], Spread[Double], and MapData. + * Array[Byte], Spread[Double], MapData and Row. */ protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { (result, expected) match { @@ -88,6 +89,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0f376c4b63c15..50e57737a4612 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} -import org.apache.spark.sql.types.{IntegerType, ObjectType} +import org.apache.spark.sql.types._ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -86,6 +86,12 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23590: CreateExternalRow should support interpreted execution") { + val schema = new StructType().add("a", IntegerType).add("b", StringType) + val createExternalRow = CreateExternalRow(Seq(Literal(1), Literal("x")), schema) + checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From 04e71c31603af3a13bc13300df799f003fe185f7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 7 Mar 2018 17:01:29 +0800 Subject: [PATCH 132/593] [MINOR][YARN] Add disable yarn.nodemanager.vmem-check-enabled option to memLimitExceededLogMessage My spark application sometimes will throw `Container killed by YARN for exceeding memory limits`. Even I increased `spark.yarn.executor.memoryOverhead` to 10G, this error still happen. The latest config: memory-config And error message: ``` ExecutorLostFailure (executor 121 exited caused by one of the running tasks) Reason: Container killed by YARN for exceeding memory limits. 30.7 GB of 30 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead. ``` This is because of [Linux glibc >= 2.10 (RHEL 6) malloc may show excessive virtual memory usage](https://www.ibm.com/developerworks/community/blogs/kevgrig/entry/linux_glibc_2_10_rhel_6_malloc_may_show_excessive_virtual_memory_usage?lang=en). So disable `yarn.nodemanager.vmem-check-enabled` looks like a good option as [MapR mentioned ](https://mapr.com/blog/best-practices-yarn-resource-management). This PR add disable `yarn.nodemanager.vmem-check-enabled` option to memLimitExceededLogMessage. More details: https://issues.apache.org/jira/browse/YARN-4714 https://stackoverflow.com/a/31450291 https://stackoverflow.com/a/42091255 After this PR: yarn N/A Author: Yuming Wang Author: Yuming Wang Closes #20735 from wangyum/YARN-4714. Change-Id: Ie10836e2c07b6384d228c3f9e89f802823bd9f16 --- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 506adb363aa90..a537243d641cb 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -736,7 +736,8 @@ private object YarnAllocator { def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) val diag = if (matcher.find()) " " + matcher.group() + "." else "" - ("Container killed by YARN for exceeding memory limits." + diag - + " Consider boosting spark.yarn.executor.memoryOverhead.") + s"Container killed by YARN for exceeding memory limits. $diag " + + "Consider boosting spark.yarn.executor.memoryOverhead or " + + "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." } } From 33c2cb22b3b246a413717042a5f741da04ded69d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 7 Mar 2018 13:10:51 +0100 Subject: [PATCH 133/593] [SPARK-23611][SQL] Add a helper function to check exception for expr evaluation ## What changes were proposed in this pull request? This pr added a helper function in `ExpressionEvalHelper` to check exceptions in all the path of expression evaluation. ## How was this patch tested? Modified the existing tests. Author: Takeshi Yamamuro Closes #20748 from maropu/SPARK-23611. --- .../expressions/ExpressionEvalHelper.scala | 83 ++++++++++++++----- .../expressions/MathExpressionsSuite.scala | 2 +- .../expressions/MiscExpressionsSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- .../expressions/ObjectExpressionsSuite.scala | 17 ++-- .../expressions/RegexpExpressionsSuite.scala | 8 +- .../expressions/StringExpressionsSuite.scala | 2 +- .../expressions/TimeWindowSuite.scala | 2 +- 8 files changed, 79 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 29f0cc0d991aa..58d0c07622eb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException @@ -45,11 +47,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } - protected def checkEvaluation( - expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + } + + protected def checkEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -95,7 +101,31 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + inputRow: InternalRow, + expectedErrMsg: String): Unit = { + + def checkException(eval: => Unit, testMode: String): Unit = { + withClue(s"($testMode)") { + val errMsg = intercept[T] { + eval + }.getMessage + if (errMsg != expectedErrMsg) { + fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") + } + } + } + val expr = prepareEvaluation(expression) + checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") + checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode") + } + } + + protected def evaluateWithoutCodegen( + expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { case n: Nondeterministic => n.initialize(0) case _ => @@ -124,7 +154,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!checkResult(actual, expected, expression.dataType)) { @@ -139,33 +169,29 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = evaluateWithGeneratedMutableProjection(expression, inputRow) + if (!checkResult(actual, expected, expression.dataType)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + private def evaluateWithGeneratedMutableProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected, expression.dataType)) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") - } + plan(inputRow).get(0, expression.dataType) } protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // SPARK-16489 Explicitly doing code generation twice so code gen will fail if - // some expression is reusing variable names across different instances. - // This behavior is tested in ExpressionEvalHelperSuite. - val plan = generateProject( - UnsafeProjection.create( - Alias(expression, s"Optimized($expression)1")() :: - Alias(expression, s"Optimized($expression)2")() :: Nil), - expression) - - val unsafeRow = plan(inputRow) + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -185,6 +211,21 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + private def evaluateWithUnsafeProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): InternalRow = { + // SPARK-16489 Explicitly doing code generation twice so code gen will fail if + // some expression is reusing variable names across different instances. + // This behavior is tested in ExpressionEvalHelperSuite. + val plan = generateProject( + UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil), + expression) + + plan(inputRow) + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, @@ -294,7 +335,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try { - evaluate(expr, inputRow) + evaluateWithoutCodegen(expr, inputRow) } catch { case e: Exception => fail(s"Exception evaluating $expr", e) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 39e0060d41dd4..3a094079380fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -124,7 +124,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkNaNWithoutCodegen( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!actual.asInstanceOf[Double].isNaN) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index facc863081303..a21c139fe71d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -41,6 +41,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("uuid") { checkEvaluation(Length(Uuid()), 36) - assert(evaluate(Uuid()) !== evaluate(Uuid())) + assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index cc6c15cb2c909..424c3a4696077 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -51,7 +51,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("AssertNotNUll") { val ex = intercept[RuntimeException] { - evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String])) }.getMessage assert(ex.contains("Null value appeared in non-nullable field")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 50e57737a4612..cbfbb6573ae8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -100,14 +100,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // If an input row or a field are null, a runtime exception will be thrown - val errMsg1 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(null))) - }.getMessage - assert(errMsg1 === "The input external row cannot be null.") - - val errMsg2 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) - }.getMessage - assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(null)), + "The input external row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(Row(null))), + "The 0th field 'c0' of input row cannot be null.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2a0a42c65b086..d532dc4f77198 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -100,12 +100,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // invalid escaping val invalidEscape = intercept[AnalysisException] { - evaluate("""a""" like """\a""") + evaluateWithoutCodegen("""a""" like """\a""") } assert(invalidEscape.getMessage.contains("pattern")) val endEscape = intercept[AnalysisException] { - evaluate("""a""" like """a\""") + evaluateWithoutCodegen("""a""" like """a\""") } assert(endEscape.getMessage.contains("pattern")) @@ -147,11 +147,11 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") + evaluateWithoutCodegen("abbbbc" rlike "**") } intercept[java.util.regex.PatternSyntaxException] { val regex = 'a.string.at(0) - evaluate("abbbbc" rlike regex, create_row("**")) + evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 97ddbeba2c5ca..9a1a4da074ce3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -756,7 +756,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // exceptional cases intercept[java.util.regex.PatternSyntaxException] { - evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), + evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), Literal("QUERY"), Literal("???")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index d6c8fcf291842..351d4d0c2eac9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -27,7 +27,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva test("time window is unevaluable") { intercept[UnsupportedOperationException] { - evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + evaluateWithoutCodegen(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) } } From aff7d81cb73133483fc2256ca10e21b4b8101647 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 7 Mar 2018 18:31:59 +0100 Subject: [PATCH 134/593] [SPARK-23591][SQL] Add interpreted execution to EncodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to EncodeUsingSerializer. ## How was this patch tested? added UT Author: Marco Gaido Closes #20751 from mgaido91/SPARK-23591. --- .../expressions/objects/objects.scala | 114 ++++++++++-------- .../expressions/ObjectExpressionsSuite.scala | 16 ++- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 721d589709131..7bbc3c732e782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -105,6 +105,61 @@ trait InvokeLike extends Expression with NonSQLExpression { } } +/** + * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]] + */ +trait SerializerSupport { + /** + * If true, Kryo serialization is used, otherwise the Java one is used + */ + val kryo: Boolean + + /** + * The serializer instance to be used for serialization/deserialization in interpreted execution + */ + lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo) + + /** + * Adds a immutable state to the generated class containing a reference to the serializer. + * @return a string containing the name of the variable referencing the serializer + */ + def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + val (serializerInstance, serializerInstanceClass) = { + if (kryo) { + ("kryoSerializer", + classOf[KryoSerializerInstance].getName) + } else { + ("javaSerializer", + classOf[JavaSerializerInstance].getName) + } + } + val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" + // Code to initialize the serializer + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + s""" + |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); + """.stripMargin) + serializerInstance + } +} + +object SerializerSupport { + /** + * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (is + * `useKryo` is set to `true`) or a `JavaSerializerInstance`. + */ + def newSerializer(useKryo: Boolean): SerializerInstance = { + // try conf from env, otherwise create a new one + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val s = if (useKryo) { + new KryoSerializer(conf) + } else { + new JavaSerializer(conf) + } + s.newInstance() + } +} + /** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. @@ -1154,36 +1209,14 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to serialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) @@ -1207,33 +1240,10 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index cbfbb6573ae8e..346b13277c709 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -109,4 +110,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(null))), "The 0th field 'c0' of input row cannot be null.") } + + test("SPARK-23591: EncodeUsingSerializer should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val expected = serializer.newInstance().serialize(new Integer(1)).array() + val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) + checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) + checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 53561d27c45db31893bcabd4aca2387fde869b72 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Mar 2018 09:37:42 -0800 Subject: [PATCH 135/593] [SPARK-23291][SQL][R] R's substr should not reduce starting position by 1 when calling Scala API ## What changes were proposed in this pull request? Seems R's substr API treats Scala substr API as zero based and so subtracts the given starting position by 1. Because Scala's substr API also accepts zero-based starting position (treated as the first element), so the current R's substr test results are correct as they all use 1 as starting positions. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20464 from viirya/SPARK-23291. --- R/pkg/R/column.R | 10 ++++++++-- R/pkg/tests/fulltests/test_sparkSQL.R | 1 + docs/sparkr.md | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9727efc354f10..7926a9a2467ee 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -161,12 +161,18 @@ setMethod("alias", #' @aliases substr,Column-method #' #' @param x a Column. -#' @param start starting position. +#' @param start starting position. It should be 1-base. #' @param stop ending position. +#' @examples +#' \dontrun{ +#' df <- createDataFrame(list(list(a="abcdef"))) +#' collect(select(df, substr(df$a, 1, 4))) # the result is `abcd`. +#' collect(select(df, substr(df$a, 2, 4))) # the result is `bcd`. +#' } #' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { - jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + jc <- callJMethod(x@jc, "substr", as.integer(start), as.integer(stop - start + 1)) column(jc) }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bd0a0dcd0674c..439191adb23ea 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1651,6 +1651,7 @@ test_that("string operators", { expect_false(first(select(df, startsWith(df$name, "m")))[[1]]) expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(first(select(df, substr(df$name, 4, 6)))[[1]], "hae") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { expect_true(startsWith("Hello World", "Hello")) expect_false(endsWith("Hello World", "a")) diff --git a/docs/sparkr.md b/docs/sparkr.md index 6685b585a393a..2909247e79e95 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -663,3 +663,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. + +## Upgrading to Spark 2.4.0 + + - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. From c99fc9ad9b600095baba003053dbf84304ca392b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 7 Mar 2018 13:42:06 -0800 Subject: [PATCH 136/593] [SPARK-23550][CORE] Cleanup `Utils`. A few different things going on: - Remove unused methods. - Move JSON methods to the only class that uses them. - Move test-only methods to TestUtils. - Make getMaxResultSize() a config constant. - Reuse functionality from existing libraries (JRE or JavaUtils) where possible. The change also includes changes to a few tests to call `Utils.createTempFile` correctly, so that temp dirs are created under the designated top-level temp dir instead of potentially polluting git index. Author: Marcelo Vanzin Closes #20706 from vanzin/SPARK-23550. --- .../scala/org/apache/spark/TestUtils.scala | 26 +++- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 4 +- .../spark/internal/config/ConfigBuilder.scala | 3 +- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/TaskSetManager.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 124 +++++++++-------- .../scala/org/apache/spark/util/Utils.scala | 131 +----------------- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../scala/org/apache/spark/DriverSuite.scala | 2 +- .../spark/deploy/SparkSubmitSuite.scala | 18 +-- .../spark/scheduler/ReplayListenerSuite.scala | 12 +- .../org/apache/spark/util/UtilsSuite.scala | 1 + scalastyle-config.xml | 2 +- .../datasources/orc/OrcSourceSuite.scala | 4 +- .../metric/SQLMetricsTestUtils.scala | 6 +- .../sql/sources/PartitionedWriteSuite.scala | 15 +- .../HiveThriftServer2Suites.scala | 2 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 20 +-- .../spark/streaming/CheckpointSuite.scala | 2 +- .../spark/streaming/MapWithStateSuite.scala | 2 +- 21 files changed, 152 insertions(+), 236 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 93e7ee3d2a404..b5c4c705dcbc7 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,7 +22,7 @@ import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate -import java.util.Arrays +import java.util.{Arrays, Properties} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ @@ -35,6 +35,7 @@ import scala.sys.process.{Process, ProcessLogger} import scala.util.Try import com.google.common.io.{ByteStreams, Files} +import org.apache.log4j.PropertyConfigurator import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -256,6 +257,29 @@ private[spark] object TestUtils { s"Can't find $numExecutors executors before $timeout milliseconds elapsed") } + /** + * config a log4j properties used for testsuite + */ + def configTestLog4j(level: String): Unit = { + val pro = new Properties() + pro.put("log4j.rootLogger", s"$level, console") + pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") + pro.put("log4j.appender.console.target", "System.err") + pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") + pro.put("log4j.appender.console.layout.ConversionPattern", + "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") + PropertyConfigurator.configure(pro) + } + + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9db7a1fe3106d..e7796d4ddbe34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{ByteArrayOutputStream, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream} import java.lang.reflect.InvocationTargetException import java.net.URI import java.nio.charset.StandardCharsets @@ -233,7 +233,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { - name = Utils.stripDirectory(primaryResource) + name = new File(primaryResource).getName() } // Action should be SUBMIT unless otherwise specified diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2c3a8ef74800b..dcec3ec21b546 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -35,6 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} @@ -141,8 +142,7 @@ private[spark] class Executor( conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20), RpcUtils.maxMessageSizeBytes(conf)) - // Limit of bytes for total size of results (default is 1GB) - private val maxResultSize = Utils.getMaxResultSize(conf) + private val maxResultSize = conf.get(MAX_RESULT_SIZE) // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b0cd7110a3b47..f27aca03773a9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -23,6 +23,7 @@ import java.util.regex.PatternSyntaxException import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} +import org.apache.spark.util.Utils private object ConfigHelpers { @@ -45,7 +46,7 @@ private object ConfigHelpers { } def stringToSeq[T](str: String, converter: String => T): Seq[T] = { - str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter) + Utils.stringToSeq(str).map(converter) } def seqToString[T](v: Seq[T], stringConverter: T => String): String = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bbfcfbaa7363c..a313ad0554a3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -520,4 +520,9 @@ package object config { .checkValue(v => v > 0, "The threshold should be positive.") .createWithDefault(10000000) + private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize") + .doc("Size limit for results.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1g") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 886c2c99f1ff3..d958658527f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -64,8 +64,7 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) - // Limit of bytes for total size of results (default is 1GB) - val maxResultSize = Utils.getMaxResultSize(conf) + val maxResultSize = conf.get(config.MAX_RESULT_SIZE) val speculationEnabled = conf.getBoolean("spark.speculation", false) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ff83301d631c4..40383fe05026b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -48,7 +48,7 @@ import org.apache.spark.storage._ * To ensure that we provide these guarantees, follow these rules when modifying these methods: * * - Never delete any JSON fields. - * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * - Any new JSON fields should be optional; use `jsonOption` when reading these fields * in `*FromJson` methods. */ private[spark] object JsonProtocol { @@ -408,7 +408,7 @@ private[spark] object JsonProtocol { ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => ("Kill Reason" -> taskKilled.reason) - case _ => Utils.emptyJson + case _ => emptyJson } ("Reason" -> reason) ~ json } @@ -422,7 +422,7 @@ private[spark] object JsonProtocol { def jobResultToJson(jobResult: JobResult): JValue = { val result = Utils.getFormattedClassName(jobResult) val json = jobResult match { - case JobSucceeded => Utils.emptyJson + case JobSucceeded => emptyJson case jobFailed: JobFailed => JObject("Exception" -> exceptionToJson(jobFailed.exception)) } @@ -573,7 +573,7 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -586,7 +586,7 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -597,11 +597,11 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] val submissionTime = - Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 - val stageInfos = Utils.jsonOption(json \ "Stage Infos") + val stageInfos = jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") @@ -613,7 +613,7 @@ private[spark] object JsonProtocol { def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] val completionTime = - Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") SparkListenerJobEnd(jobId, completionTime, jobResult) } @@ -630,15 +630,15 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) - val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val maxOnHeapMem = jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) SparkListenerBlockManagerRemoved(time, blockManagerId) } @@ -648,11 +648,11 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] - val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) + val appId = jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - val appAttemptId = Utils.jsonOption(json \ "App Attempt ID").map(_.extract[String]) - val driverLogs = Utils.jsonOption(json \ "Driver Logs").map(mapFromJson) + val appAttemptId = jsonOption(json \ "App Attempt ID").map(_.extract[String]) + val driverLogs = jsonOption(json \ "Driver Logs").map(mapFromJson) SparkListenerApplicationStart(appName, appId, time, sparkUser, appAttemptId, driverLogs) } @@ -703,19 +703,19 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + val attemptId = jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") - val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) - val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) - val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) + val details = jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") + val submissionTime = jsonOption(json \ "Submission Time").map(_.extract[Long]) + val completionTime = jsonOption(json \ "Completion Time").map(_.extract[Long]) + val failureReason = jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = { - Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -735,17 +735,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) + val attempt = jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] val executorId = (json \ "Executor ID").extract[String].intern() val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) + val speculative = jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) - val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { + val killed = jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -762,13 +762,13 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) - val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } - val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val name = jsonOption(json \ "Name").map(_.extract[String]) + val update = jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } + val internal = jsonOption(json \ "Internal").exists(_.extract[Boolean]) val countFailedValues = - Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) - val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) + jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -821,49 +821,49 @@ private[spark] object JsonProtocol { metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) // Shuffle read metrics - Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => val readMetrics = metrics.createTempShuffleReadMetrics() readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + jsonOption(readJson \ "Remote Bytes Read To Disk") .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( - Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) readMetrics.incRecordsRead( - Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } // Shuffle write metrics // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. - Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) writeMetrics.incRecordsWritten( - Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } // Output metrics - Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) outputMetrics.setRecordsWritten( - Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics - Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) inputMetrics.incRecordsRead( - Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks - Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + jsonOption(json \ "Updated Blocks").foreach { blocksJson => metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") @@ -897,7 +897,7 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + val message = jsonOption(json \ "Message").map(_.extract[String]) new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => @@ -905,9 +905,9 @@ private[spark] object JsonProtocol { val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") val fullStackTrace = - Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull + jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x - val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + val accumUpdates = jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => { acc.toInfo(Some(acc.value), None) @@ -915,21 +915,21 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => - val killReason = Utils.jsonOption(json \ "Kill Reason") + val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility // for reading those logs, we need to provide default values for all the fields. - val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) - val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) - val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + val jobId = jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => - val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) - val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + val exitCausedByApp = jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) + val executorId = jsonOption(json \ "Executor ID").map(_.extract[String]) + val reason = jsonOption(json \ "Loss Reason").map(_.extract[String]) ExecutorLostFailure( executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true), @@ -968,11 +968,11 @@ private[spark] object JsonProtocol { def rddInfoFromJson(json: JValue): RDDInfo = { val rddId = (json \ "RDD ID").extract[Int] val name = (json \ "Name").extract[String] - val scope = Utils.jsonOption(json \ "Scope") + val scope = jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val callsite = jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) val storageLevel = storageLevelFromJson(json \ "Storage Level") @@ -1029,7 +1029,7 @@ private[spark] object JsonProtocol { } def propertiesFromJson(json: JValue): Properties = { - Utils.jsonOption(json).map { value => + jsonOption(json).map { value => val properties = new Properties mapFromJson(json).foreach { case (k, v) => properties.setProperty(k, v) } properties @@ -1058,4 +1058,14 @@ private[spark] object JsonProtocol { e } + /** Return an option that translates JNothing to None */ + private def jsonOption(json: JValue): Option[JValue] = { + json match { + case JNothing => None + case value: JValue => Some(value) + } + } + + private def emptyJson: JObject = JObject(List[JField]()) + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2e2a4a259e9af..29d26ea2c85df 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} +import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -51,9 +51,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException -import org.json4s._ import org.slf4j.Logger import org.apache.spark._ @@ -1017,70 +1015,18 @@ private[spark] object Utils extends Logging { " " + (System.currentTimeMillis - startTimeMs) + " ms" } - private def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - /** - * Lists files recursively. - */ - def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } - /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. * Throws an exception if deletion is unsuccessful. */ - def deleteRecursively(file: File) { + def deleteRecursively(file: File): Unit = { if (file != null) { - try { - if (file.isDirectory && !isSymlink(file)) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - ShutdownHookManager.removeShutdownDeleteDir(file) - } - } finally { - if (file.delete()) { - logTrace(s"${file.getAbsolutePath} has been deleted") - } else { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } + JavaUtils.deleteRecursively(file) + ShutdownHookManager.removeShutdownDeleteDir(file) } } - /** - * Check to see if file is a symbolic link. - */ - def isSymlink(file: File): Boolean = { - return Files.isSymbolicLink(Paths.get(file.toURI)) - } - /** * Determines if a directory contains any files newer than cutoff seconds. * @@ -1828,7 +1774,7 @@ private[spark] object Utils extends Logging { * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower * in the current version of Scala. */ - def getIteratorSize[T](iterator: Iterator[T]): Long = { + def getIteratorSize(iterator: Iterator[_]): Long = { var count = 0L while (iterator.hasNext) { count += 1L @@ -1875,17 +1821,6 @@ private[spark] object Utils extends Logging { obj.getClass.getSimpleName.replace("$", "") } - /** Return an option that translates JNothing to None */ - def jsonOption(json: JValue): Option[JValue] = { - json match { - case JNothing => None - case value: JValue => Some(value) - } - } - - /** Return an empty JSON object */ - def emptyJson: JsonAST.JObject = JObject(List[JField]()) - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -1900,15 +1835,6 @@ private[spark] object Utils extends Logging { getHadoopFileSystem(new URI(path), conf) } - /** - * Return the absolute path of a file in the given directory. - */ - def getFilePath(dir: File, fileName: String): Path = { - assert(dir.isDirectory) - val path = new File(dir, fileName).getAbsolutePath - new Path(path) - } - /** * Whether the underlying operating system is Windows. */ @@ -1931,13 +1857,6 @@ private[spark] object Utils extends Logging { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } - /** - * Strip the directory from a path name - */ - def stripDirectory(path: String): String = { - new File(path).getName - } - /** * Terminates a process waiting for at most the specified duration. * @@ -2348,36 +2267,6 @@ private[spark] object Utils extends Logging { org.apache.log4j.Logger.getRootLogger().setLevel(l) } - /** - * config a log4j properties used for testsuite - */ - def configTestLog4j(level: String): Unit = { - val pro = new Properties() - pro.put("log4j.rootLogger", s"$level, console") - pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") - pro.put("log4j.appender.console.target", "System.err") - pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") - pro.put("log4j.appender.console.layout.ConversionPattern", - "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") - PropertyConfigurator.configure(pro) - } - - def invoke( - clazz: Class[_], - obj: AnyRef, - methodName: String, - args: (Class[_], AnyRef)*): AnyRef = { - val (types, values) = args.unzip - val method = clazz.getDeclaredMethod(methodName, types: _*) - method.setAccessible(true) - method.invoke(obj, values.toSeq: _*) - } - - // Limit of bytes for total size of results (default is 1GB) - def getMaxResultSize(conf: SparkConf): Long = { - memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 - } - /** * Return the current system LD_LIBRARY_PATH name */ @@ -2610,16 +2499,6 @@ private[spark] object Utils extends Logging { SignalUtils.registerLogger(log) } - /** - * Unions two comma-separated lists of files and filters out empty strings. - */ - def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { - var allFiles = Set.empty[String] - leftList.foreach { value => allFiles ++= value.split(",") } - rightList.foreach { value => allFiles ++= value.split(",") } - allFiles.filter { _.nonEmpty } - } - /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 24a55df84a240..0d5c5ea7903e9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -95,7 +95,7 @@ public void tearDown() { @SuppressWarnings("unchecked") public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir("test", "test"); + tempDir = Utils.createTempDir(null, "test"); mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 962945e5b6bb1..896cd2e80aaef 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits { */ object DriverWithoutCleanup { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 803a38d77fb82..d265643a80b4e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.TestUtils import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ @@ -761,18 +762,6 @@ class SparkSubmitSuite } } - test("comma separated list of files are unioned correctly") { - val left = Option("/tmp/a.jar,/tmp/b.jar") - val right = Option("/tmp/c.jar,/tmp/a.jar") - val emptyString = Option("") - Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar")) - Utils.unionFileLists(emptyString, emptyString) should be (Set.empty) - Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) - } - test("support glob path") { val tmpJarDir = Utils.createTempDir() val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) @@ -1042,6 +1031,7 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1076,7 +1066,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits { object JarCreationTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -1100,7 +1090,7 @@ object JarCreationTest extends Logging { object SimpleApplicationTest { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 73e7b3fe8c1de..e24d550a62665 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -47,7 +47,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Simple replay") { - val logFilePath = Utils.getFilePath(testDir, "events.txt") + val logFilePath = getFilePath(testDir, "events.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, @@ -97,7 +97,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // scalastyle:on println } - val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress") + val logFilePath = getFilePath(testDir, "events.lz4.inprogress") val bytes = buffered.toByteArray Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream => fstream.write(bytes, 0, buffered.size) @@ -129,7 +129,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Replay incompatible event log") { - val logFilePath = Utils.getFilePath(testDir, "incompatible.txt") + val logFilePath = getFilePath(testDir, "incompatible.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Incompatible App", None, @@ -226,6 +226,12 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } } + private def getFilePath(dir: File, fileName: String): Path = { + assert(dir.isDirectory) + val path = new File(dir, fileName).getAbsolutePath + new Path(path) + } + /** * A simple listener that buffers all the events it receives. * diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index eaea6b030c154..3b4273184f1e9 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -648,6 +648,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") + sourceDir.mkdir() val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e2fa5754afaee..e65e3aafe5b5b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -229,7 +229,7 @@ This file is divided into 3 sections: extractOpt - Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter is slower. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 523f7cf77e103..8a3bbd03a26dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -39,8 +39,8 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { protected override def beforeAll(): Unit = { super.beforeAll() - orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - orcTableDir = Utils.createTempDir("orctests", "sparksql") + orcTableAsDir = Utils.createTempDir(namePrefix = "orctests") + orcTableDir = Utils.createTempDir(namePrefix = "orctests") sparkContext .makeRDD(1 to 10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 122d28798136f..534d8bb629b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -21,13 +21,13 @@ import java.io.File import scala.collection.mutable.HashMap +import org.apache.spark.TestUtils import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils trait SQLMetricsTestUtils extends SQLTestUtils { @@ -91,7 +91,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) .write.format(dataFormat).mode("overwrite").insertInto(tableName) } - assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } } @@ -121,7 +121,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { .mode("overwrite") .insertInto(tableName) } - assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 0fe33e87318a5..27c983f270bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.TestUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils @@ -86,15 +87,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -106,7 +107,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -133,14 +134,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -148,7 +149,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 496f8c82a6c61..b32c547cefefe 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -804,7 +804,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected var metastorePath: File = _ protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ private var logTailingProcess: Process = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 2d31781132edc..079fe45860544 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -330,7 +330,7 @@ class HiveSparkSubmitSuite object SetMetastoreURLTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true) val builder = SparkSession.builder() @@ -368,7 +368,7 @@ object SetMetastoreURLTest extends Logging { object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = @@ -447,7 +447,7 @@ object SetWarehouseLocationTest extends Logging { // can load the jar defined with the function. object TemporaryHiveUDFTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -485,7 +485,7 @@ object TemporaryHiveUDFTest extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest1 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -523,7 +523,7 @@ object PermanentHiveUDFTest1 extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest2 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -558,7 +558,7 @@ object PermanentHiveUDFTest2 extends Logging { // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val hiveWarehouseLocation = Utils.createTempDir() conf.set("spark.ui.enabled", "false") @@ -628,7 +628,7 @@ object SparkSubmitClassLoaderTest extends Logging { // We test if we can correctly set spark sql configurations when HiveContext is used. object SparkSQLConfTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") // We override the SparkConf to add spark.sql.hive.metastore.version and // spark.sql.hive.metastore.jars to the beginning of the conf entry array. // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but @@ -669,7 +669,7 @@ object SPARK_9757 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( @@ -718,7 +718,7 @@ object SPARK_11009 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -749,7 +749,7 @@ object SPARK_14244 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ee2fd45a7e851..19b621f11759d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -97,7 +97,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => val batchDurationMillis = batchDuration.milliseconds // Setup the stream computation - val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + val checkpointDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString logDebug(s"Using checkpoint directory $checkpointDir") val ssc = createContextForCheckpointOperation(batchDuration) require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 3b662ec1833aa..06c0c2aa97ee1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -39,7 +39,7 @@ class MapWithStateSuite extends SparkFunSuite before { StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir("checkpoint") + checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") } after { From ac76eff6a88f6358a321b84cb5e60fb9d6403419 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 7 Mar 2018 13:51:44 -0800 Subject: [PATCH 137/593] [SPARK-23525][SQL] Support ALTER TABLE CHANGE COLUMN COMMENT for external hive table ## What changes were proposed in this pull request? The following query doesn't work as expected: ``` CREATE EXTERNAL TABLE ext_table(a STRING, b INT, c STRING) PARTITIONED BY (d STRING) LOCATION 'sql/core/spark-warehouse/ext_table'; ALTER TABLE ext_table CHANGE a a STRING COMMENT "new comment"; DESC ext_table; ``` The comment of column `a` is not updated, that's because `HiveExternalCatalog.doAlterTable` ignores table schema changes. To fix the issue, we should call `doAlterTableDataSchema` instead of `doAlterTable`. ## How was this patch tested? Updated `DDLSuite.testChangeColumn`. Author: Xingbo Jiang Closes #20696 from jiangxb1987/alterColumnComment. --- .../spark/sql/execution/command/ddl.scala | 12 ++++++------ .../sql-tests/inputs/change-column.sql | 1 + .../sql-tests/results/change-column.sql.out | 19 ++++++++++++++----- .../sql/execution/command/DDLSuite.scala | 1 + 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 964cbca049b27..bf4d96fa18d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -314,8 +314,8 @@ case class AlterTableChangeColumnCommand( val resolver = sparkSession.sessionState.conf.resolver DDLUtils.verifyAlterTableType(catalog, table, isView = false) - // Find the origin column from schema by column name. - val originColumn = findColumnByName(table.schema, columnName, resolver) + // Find the origin column from dataSchema by column name. + val originColumn = findColumnByName(table.dataSchema, columnName, resolver) // Throw an AnalysisException if the column name/dataType is changed. if (!columnEqual(originColumn, newColumn, resolver)) { throw new AnalysisException( @@ -324,7 +324,7 @@ case class AlterTableChangeColumnCommand( s"'${newColumn.name}' with type '${newColumn.dataType}'") } - val newSchema = table.schema.fields.map { field => + val newDataSchema = table.dataSchema.fields.map { field => if (field.name == originColumn.name) { // Create a new column from the origin column with the new comment. addComment(field, newColumn.getComment) @@ -332,8 +332,7 @@ case class AlterTableChangeColumnCommand( field } } - val newTable = table.copy(schema = StructType(newSchema)) - catalog.alterTable(newTable) + catalog.alterTableDataSchema(tableName, StructType(newDataSchema)) Seq.empty[Row] } @@ -345,7 +344,8 @@ case class AlterTableChangeColumnCommand( schema.fields.collectFirst { case field if resolver(field.name, name) => field }.getOrElse(throw new AnalysisException( - s"Invalid column reference '$name', table schema is '${schema}'")) + s"Can't find column `$name` given table data columns " + + s"${schema.fieldNames.mkString("[`", "`, `", "`]")}")) } // Add the comment to a column, if comment is empty, return the original column. diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index ad0f885f63d3d..2909024e4c9f7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column -- Change column in partition spec (not supported yet) CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d); ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT; +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c79..ff1ecbcc44c23 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 33 -- !query 0 @@ -154,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))'; +Can't find column `invalid_col` given table data columns [`a`, `b`, `c`]; -- !query 16 @@ -291,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT -- !query 30 -DROP TABLE test_change +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C' -- !query 30 schema struct<> -- !query 30 output - +org.apache.spark.sql.AnalysisException +Can't find column `c` given table data columns [`a`, `b`]; -- !query 31 -DROP TABLE partition_table +DROP TABLE test_change -- !query 31 schema struct<> -- !query 31 output + + +-- !query 32 +DROP TABLE partition_table +-- !query 32 schema +struct<> +-- !query 32 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index db9023b7ec8b6..4041176262426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1597,6 +1597,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Ensure that change column will preserve other metadata fields. sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") assert(getMetadata("col1").getString("key") == "value") + assert(getMetadata("col1").getString("comment") == "this is col1") } test("drop build-in function") { From 77c91cc746f93e609c412f3a220495d9e931f696 Mon Sep 17 00:00:00 2001 From: jx158167 Date: Wed, 7 Mar 2018 20:08:32 -0800 Subject: [PATCH 138/593] [SPARK-23524] Big local shuffle blocks should not be checked for corruption. ## What changes were proposed in this pull request? In current code, all local blocks will be checked for corruption no matter it's big or not. The reasons are as below: Size in FetchResult for local block is set to be 0 (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L327) SPARK-4105 meant to only check the small blocks(size Closes #20685 from jinxing64/SPARK-23524. --- .../storage/ShuffleBlockFetcherIterator.scala | 14 +++--- .../ShuffleBlockFetcherIteratorSuite.scala | 45 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 98b5a735a4529..dd9df74689a13 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() @@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - bytesInFlight -= size + if (!localBlocks.contains(blockId)) { + bytesInFlight -= size + } if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) @@ -583,8 +587,8 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block successfully. * @param blockId block id * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5bfe9905ff17b..692ae3bf597e0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("big blocks are not checked for corruption") { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + doReturn(10000L).when(corruptBuffer).size() + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ) + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ) + + val transfer = createMockTransfer( + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths), + (remoteBmId, remoteBlockLengths) + ) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + // Blocks should be returned without exceptions. + assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) From fe22f32041572596a22e5f7441fa0bfbd9608648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Mar 2018 10:50:09 +0100 Subject: [PATCH 139/593] [SPARK-23620] Splitting thread dump lines by using the br tag ## What changes were proposed in this pull request? I propose to replace `'\n'` by the `
    ` tag in generated html of thread dump page. The `
    ` tag will split thread lines in more reliable way. For now it could look like on the screen shot if the html is proxied and `'\n'` is replaced by another whitespace. The changes allow to more easily read and copy stack traces. ## How was this patch tested? I tested it manually by checking the thread dump page and its source. Author: Maxim Gekk Closes #20762 from MaxGekk/br-thread-dump. --- .../org/apache/spark/status/api/v1/api.scala | 24 ++++++++++++++++++- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 369e98b683b1a..971d7e90fa7b8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -19,6 +19,8 @@ package org.apache.spark.status.api.v1 import java.lang.{Long => JLong} import java.util.Date +import scala.xml.{NodeSeq, Text} + import com.fasterxml.jackson.annotation.JsonIgnoreProperties import com.fasterxml.jackson.databind.annotation.JsonDeserialize @@ -317,11 +319,31 @@ class RuntimeInfo private[spark]( val javaHome: String, val scalaVersion: String) +case class StackTrace(elems: Seq[String]) { + override def toString: String = elems.mkString + + def html: NodeSeq = { + val withNewLine = elems.foldLeft(NodeSeq.Empty) { (acc, elem) => + if (acc.isEmpty) { + acc :+ Text(elem) + } else { + acc :+
    :+ Text(elem) + } + } + + withNewLine + } + + def mkString(start: String, sep: String, end: String): String = { + elems.mkString(start, sep, end) + } +} + case class ThreadStackTrace( val threadId: Long, val threadName: String, val threadState: Thread.State, - val stackTrace: String, + val stackTrace: StackTrace, val blockedByThreadId: Option[Long], val blockedByLock: String, val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 7a9aaf29a8b05..9bb026c60565e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,7 +60,7 @@ private[ui] class ExecutorThreadDumpPage( {thread.threadName} {thread.threadState} {blockedBy}{heldLocks} - {thread.stackTrace} + {thread.stackTrace.html} } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 29d26ea2c85df..5caedeb526469 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -61,7 +61,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.status.api.v1.ThreadStackTrace +import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2118,14 +2118,14 @@ private[spark] object Utils extends Logging { private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => + val stackTrace = StackTrace(threadInfo.getStackTrace.map { frame => monitors.get(frame) match { case Some(monitor) => monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" case None => frame.toString } - }.mkString("\n") + }) // use a set to dedup re-entrant locks that are held at multiple places val heldLocks = From 9bb239c8b174d31981dfff63baa38bb8cecfe913 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Mar 2018 20:19:55 +0900 Subject: [PATCH 140/593] [SPARK-23159][PYTHON] Update cloudpickle to v0.4.3 ## What changes were proposed in this pull request? The version of cloudpickle in PySpark was close to version 0.4.0 with some additional backported fixes and some minor additions for Spark related things. This update removes Spark related changes and matches cloudpickle [v0.4.3](https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.3): Changes by updating to 0.4.3 include: * Fix pickling of named tuples https://github.com/cloudpipe/cloudpickle/pull/113 * Built in type constructors for PyPy compatibility [here](https://github.com/cloudpipe/cloudpickle/commit/d84980ccaafc7982a50d4e04064011f401f17d1b) * Fix memoryview support https://github.com/cloudpipe/cloudpickle/pull/122 * Improved compatibility with other cloudpickle versions https://github.com/cloudpipe/cloudpickle/pull/128 * Several cleanups https://github.com/cloudpipe/cloudpickle/pull/121 and [here](https://github.com/cloudpipe/cloudpickle/commit/c91aaf110441991307f5097f950764079d0f9652) * [MRG] Regression on pickling classes from the __main__ module https://github.com/cloudpipe/cloudpickle/pull/149 * BUG: Handle instance methods of builtin types https://github.com/cloudpipe/cloudpickle/pull/154 * Fix #129 : do not silence RuntimeError in dump() https://github.com/cloudpipe/cloudpickle/pull/153 ## How was this patch tested? Existing pyspark.tests using python 2.7.14, 3.5.2, 3.6.3 Author: Bryan Cutler Closes #20373 from BryanCutler/pyspark-update-cloudpickle-42-SPARK-23159. --- python/pyspark/accumulators.py | 1 - python/pyspark/cloudpickle.py | 320 ++++++++++++++------------------- python/pyspark/serializers.py | 14 +- 3 files changed, 151 insertions(+), 184 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc747..7def676b89a24 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,7 +94,6 @@ else: import socketserver as SocketServer import threading -from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 40e91a2d0655d..ea845b98b3db2 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -57,7 +57,6 @@ import types import weakref -from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -181,6 +180,32 @@ def _builtin_type(name): return getattr(types, name) +def _make__new__factory(type_): + def _factory(): + return type_.__new__ + return _factory + + +# NOTE: These need to be module globals so that they're pickleable as globals. +_get_dict_new = _make__new__factory(dict) +_get_frozenset_new = _make__new__factory(frozenset) +_get_list_new = _make__new__factory(list) +_get_set_new = _make__new__factory(set) +_get_tuple_new = _make__new__factory(tuple) +_get_object_new = _make__new__factory(object) + +# Pre-defined set of builtin_function_or_method instances that can be +# serialized. +_BUILTIN_TYPE_CONSTRUCTORS = { + dict.__new__: _get_dict_new, + frozenset.__new__: _get_frozenset_new, + set.__new__: _get_set_new, + list.__new__: _get_list_new, + tuple.__new__: _get_tuple_new, + object.__new__: _get_object_new, +} + + if sys.version_info < (3, 4): def _walk_global_ops(code): """ @@ -237,28 +262,16 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - except pickle.PickleError: - raise - except Exception as e: - emsg = _exception_message(e) - if "'i' format requires" in emsg: - msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) - print_exec(sys.stderr) - raise pickle.PicklingError(msg) - + raise def save_memoryview(self, obj): - """Fallback to save_string""" - Pickler.save_string(self, str(obj)) + self.save(obj.tobytes()) + dispatch[memoryview] = save_memoryview - def save_buffer(self, obj): - """Fallback to save_string""" - Pickler.save_string(self,str(obj)) - if PY3: - dispatch[memoryview] = save_memoryview - else: + if not PY3: + def save_buffer(self, obj): + self.save(str(obj)) dispatch[buffer] = save_buffer def save_unsupported(self, obj): @@ -318,6 +331,24 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ + try: + should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS + except TypeError: + # Methods of builtin types aren't hashable in python 2. + should_special_case = False + + if should_special_case: + # We keep a special-cased cache of built-in type constructors at + # global scope, because these functions are structured very + # differently in different python versions and implementations (for + # example, they're instances of types.BuiltinFunctionType in + # CPython, but they're ordinary types.FunctionType instances in + # PyPy). + # + # If the function we've received is in that cache, we just + # serialize it as a lookup into the cache. + return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj) + write = self.write if name is None: @@ -344,7 +375,7 @@ def save_function(self, obj, name=None): return self.save_global(obj, name) # a builtin_function_or_method which comes in as an attribute of some - # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # object (e.g., itertools.chain.from_iterable) will end # up with modname "__main__" and so end up here. But these functions # have no __code__ attribute in CPython, so the handling for # user-defined functions below will fail. @@ -352,16 +383,13 @@ def save_function(self, obj, name=None): # for different python versions. if not hasattr(obj, '__code__'): if PY3: - if sys.version_info < (3, 4): - raise pickle.PicklingError("Can't pickle %r" % obj) - else: - rv = obj.__reduce_ex__(self.proto) + rv = obj.__reduce_ex__(self.proto) else: if hasattr(obj, '__self__'): rv = (getattr, (obj.__self__, name)) else: raise pickle.PicklingError("Can't pickle %r" % obj) - return Pickler.save_reduce(self, obj=obj, *rv) + return self.save_reduce(obj=obj, *rv) # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a @@ -420,20 +448,18 @@ def save_dynamic_class(self, obj): from global modules. """ clsdict = dict(obj.__dict__) # copy dict proxy to a dict - if not isinstance(clsdict.get('__dict__', None), property): - # don't extract dict that are properties - clsdict.pop('__dict__', None) - clsdict.pop('__weakref__', None) - - # hack as __new__ is stored differently in the __dict__ - new_override = clsdict.get('__new__', None) - if new_override: - clsdict['__new__'] = obj.__new__ - - # namedtuple is a special case for Spark where we use the _load_namedtuple function - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return + clsdict.pop('__weakref__', None) + + # On PyPy, __doc__ is a readonly attribute, so we need to include it in + # the initial skeleton class. This is safe because we know that the + # doc can't participate in a cycle with the original class. + type_kwargs = {'__doc__': clsdict.pop('__doc__', None)} + + # If type overrides __dict__ as a property, include it in the type kwargs. + # In Python 2, we can't set this attribute after construction. + __dict__ = clsdict.pop('__dict__', None) + if isinstance(__dict__, property): + type_kwargs['__dict__'] = __dict__ save = self.save write = self.write @@ -453,23 +479,12 @@ def save_dynamic_class(self, obj): # Push the rehydration function. save(_rehydrate_skeleton_class) - # Mark the start of the args for the rehydration function. + # Mark the start of the args tuple for the rehydration function. write(pickle.MARK) - # On PyPy, __doc__ is a readonly attribute, so we need to include it in - # the initial skeleton class. This is safe because we know that the - # doc can't participate in a cycle with the original class. - doc_dict = {'__doc__': clsdict.pop('__doc__', None)} - - # Create and memoize an empty class with obj's name and bases. - save(type(obj)) - save(( - obj.__name__, - obj.__bases__, - doc_dict, - )) - write(pickle.REDUCE) - self.memoize(obj) + # Create and memoize an skeleton class with obj's name and bases. + tp = type(obj) + self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -522,17 +537,22 @@ def save_function_tuple(self, func): self.memoize(func) # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(dct) - save(func.__module__) - save(closure_values) + state = { + 'globals': f_globals, + 'defaults': defaults, + 'dict': dct, + 'module': func.__module__, + 'closure_values': closure_values, + } + if hasattr(func, '__qualname__'): + state['qualname'] = func.__qualname__ + save(state) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple _extract_code_globals_cache = ( weakref.WeakKeyDictionary() - if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info") + if not hasattr(sys, "pypy_version_info") else {}) @classmethod @@ -608,37 +628,22 @@ def save_global(self, obj, name=None, pack=struct.pack): The name of this method is somewhat misleading: all types get dispatched here. """ - if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": - if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = '__main__' + if obj.__module__ == "__main__": + return self.save_dynamic_class(obj) - if modname == '__main__': - themodule = None - else: - __import__(modname) - themodule = sys.modules[modname] - self.modules.add(themodule) + try: + return Pickler.save_global(self, obj, name=name) + except Exception: + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - if hasattr(themodule, name) and getattr(themodule, name) is obj: - return Pickler.save_global(self, obj, name) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + return self.save_dynamic_class(obj) - typ = type(obj) - if typ is not obj and isinstance(obj, (type, types.ClassType)): - self.save_dynamic_class(obj) - else: - raise pickle.PicklingError("Can't pickle %r" % obj) + raise dispatch[type] = save_global dispatch[types.ClassType] = save_global @@ -709,12 +714,7 @@ def save_property(self, obj): dispatch[property] = save_property def save_classmethod(self, obj): - try: - orig_func = obj.__func__ - except AttributeError: # Python 2.6 - orig_func = obj.__get__(None, object) - if isinstance(obj, classmethod): - orig_func = orig_func.__func__ # Unbind + orig_func = obj.__func__ self.save_reduce(type(obj), (orig_func,), obj=obj) dispatch[classmethod] = save_classmethod dispatch[staticmethod] = save_classmethod @@ -754,64 +754,6 @@ def __getattribute__(self, item): if type(operator.attrgetter) is type: dispatch[operator.attrgetter] = save_attrgetter - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): - # Assert that args is a tuple or None - if not isinstance(args, tuple): - raise pickle.PicklingError("args from reduce() should be a tuple") - - # Assert that func is callable - if not hasattr(func, '__call__'): - raise pickle.PicklingError("func from reduce should be callable") - - save = self.save - write = self.write - - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - save(state) - write(pickle.BUILD) - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): # 2.7 supports partial pickling - dispatch[partial] = save_partial - - def save_file(self, obj): """Save a file""" try: @@ -867,23 +809,21 @@ def save_not_implemented(self, obj): dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented - # WeakSet was added in 2.7. - if hasattr(weakref, 'WeakSet'): - def save_weakset(self, obj): - self.save_reduce(weakref.WeakSet, (list(obj),)) - - dispatch[weakref.WeakSet] = save_weakset + def save_weakset(self, obj): + self.save_reduce(weakref.WeakSet, (list(obj),)) - """Special functions for Add-on libraries""" - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - pass + dispatch[weakref.WeakSet] = save_weakset def save_logger(self, obj): self.save_reduce(logging.getLogger, (obj.name,), obj=obj) dispatch[logging.Logger] = save_logger + """Special functions for Add-on libraries""" + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + pass + # Tornado support @@ -913,11 +853,12 @@ def dump(obj, file, protocol=2): def dumps(obj, protocol=2): file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - return file.getvalue() + try: + cp = CloudPickler(file,protocol) + cp.dump(obj) + return file.getvalue() + finally: + file.close() # including pickles unloading functions in this namespace load = pickle.load @@ -1019,18 +960,40 @@ def __reduce__(cls): return cls.__name__ -def _fill_function(func, globals, defaults, dict, module, closure_values): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). +def _fill_function(*args): + """Fills in the rest of function data into the skeleton function object + + The skeleton itself is create by _make_skel_func(). """ - func.__globals__.update(globals) - func.__defaults__ = defaults - func.__dict__ = dict - func.__module__ = module + if len(args) == 2: + func = args[0] + state = args[1] + elif len(args) == 5: + # Backwards compat for cloudpickle v0.4.0, after which the `module` + # argument was introduced + func = args[0] + keys = ['globals', 'defaults', 'dict', 'closure_values'] + state = dict(zip(keys, args[1:])) + elif len(args) == 6: + # Backwards compat for cloudpickle v0.4.1, after which the function + # state was passed as a dict to the _fill_function it-self. + func = args[0] + keys = ['globals', 'defaults', 'dict', 'module', 'closure_values'] + state = dict(zip(keys, args[1:])) + else: + raise ValueError('Unexpected _fill_value arguments: %r' % (args,)) + + func.__globals__.update(state['globals']) + func.__defaults__ = state['defaults'] + func.__dict__ = state['dict'] + if 'module' in state: + func.__module__ = state['module'] + if 'qualname' in state: + func.__qualname__ = state['qualname'] cells = func.__closure__ if cells is not None: - for cell, value in zip(cells, closure_values): + for cell, value in zip(cells, state['closure_values']): if value is not _empty_cell_value: cell_set(cell, value) @@ -1087,13 +1050,6 @@ def _find_module(mod_name): file.close() return path, description -def _load_namedtuple(name, fields): - """ - Loads a class generated by namedtuple - """ - from collections import namedtuple - return namedtuple(name, fields) - """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 91a7f093cec19..917e258d8a602 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -68,6 +68,7 @@ xrange = range from pyspark import cloudpickle +from pyspark.util import _exception_message __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] @@ -565,7 +566,18 @@ def loads(self, obj, encoding=None): class CloudPickleSerializer(PickleSerializer): def dumps(self, obj): - return cloudpickle.dumps(obj, 2) + try: + return cloudpickle.dumps(obj, 2) + except pickle.PickleError: + raise + except Exception as e: + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg + else: + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) + cloudpickle.print_exec(sys.stderr) + raise pickle.PicklingError(msg) class MarshalSerializer(FramedSerializer): From d6632d185e147fcbe6724545488ad80dce20277e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Mar 2018 20:22:07 +0900 Subject: [PATCH 141/593] [SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame ## What changes were proposed in this pull request? This PR adds a configuration to control the fallback of Arrow optimization for `toPandas` and `createDataFrame` with Pandas DataFrame. ## How was this patch tested? Manually tested and unit tests added. You can test this by: **`createDataFrame`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame(pdf, "a: map") ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame(pdf, "a: map") ``` **`toPandas`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` Author: hyukjinkwon Closes #20678 from HyukjinKwon/SPARK-23380-conf. --- docs/sql-programming-guide.md | 5 + python/pyspark/sql/dataframe.py | 120 ++++++++++++------ python/pyspark/sql/session.py | 22 +++- python/pyspark/sql/tests.py | 84 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 13 +- 5 files changed, 186 insertions(+), 58 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 01e2076555ee6..451b814ab6c53 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. +In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fallback automatically +to non-Arrow optimization implementation if an error occurs before the actual computation within Spark. +This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'. +
    {% include_example dataframe_with_arrow python/sql/arrow.py %} @@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9d8e85cde914f..8f90a367e8bf8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1992,55 +1992,91 @@ def toPandas(self): timezone = None if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + use_arrow = True try: - from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps, to_arrow_schema + from pyspark.sql.types import to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() - import pyarrow to_arrow_schema(self.schema) - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) - return _check_dataframe_localize_timestamps(pdf, timezone) - else: - return pd.DataFrame.from_records([], columns=self.columns) except Exception as e: - msg = ( - "Note: toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " - "to disable this.") - raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) - else: - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - dtype = {} + if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + use_arrow = False + else: + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps + import pyarrow + + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed unexpectedly:\n %s\n" + "Note that 'spark.sql.execution.arrow.fallback.enabled' does " + "not have an effect in such failure in the middle of " + "computation." % _exception_message(e)) + raise RuntimeError(msg) + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. - if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b3af9b82953f3..215bb3e5c5173 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) - # Fallback to create DataFrame without arrow if raise some exception + from pyspark.util import _exception_message + + if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + else: + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fa3b7203e10ac..a9fe0b425ad3e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -32,7 +32,9 @@ import datetime import array import ctypes +import warnings import py4j +from contextlib import contextmanager try: import xmlrunner @@ -48,12 +50,13 @@ else: import unittest +from pyspark.util import _exception_message + _pandas_requirement_message = None try: from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() except ImportError as e: - from pyspark.util import _exception_message # If Pandas version requirement is not satisfied, skip related tests. _pandas_requirement_message = _exception_message(e) @@ -62,7 +65,6 @@ from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() except ImportError as e: - from pyspark.util import _exception_message # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) @@ -195,6 +197,28 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3458,6 +3482,8 @@ def setUpClass(cls): cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. + cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) - def test_unsupported_datatype(self): + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with warnings.catch_warnings(record=True) as warns: + pdf = df.toPandas() + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() - df = self.spark.createDataFrame([(None,)], schema="a binary") - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): - df.toPandas() - def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce3f94618edeb..3f96112659c11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1058,7 +1058,7 @@ object SQLConf { .intConf .createWithDefault(100) - val ARROW_EXECUTION_ENABLE = + val ARROW_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas, and " + @@ -1068,6 +1068,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .booleanConf + .createWithDefault(true) + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") .doc("When using Apache Arrow, limit the maximum number of records that can be written " + @@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging { def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED) + + def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) From 2cb23a8f51a151970c121015fcbad9beeafa8295 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 8 Mar 2018 20:29:07 +0900 Subject: [PATCH 142/593] [SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF ## What changes were proposed in this pull request? This PR proposes to support an alternative function from with group aggregate pandas UDF. The current form: ``` def foo(pdf): return ... ``` Takes a single arg that is a pandas DataFrame. With this PR, an alternative form is supported: ``` def foo(key, pdf): return ... ``` The alternative form takes two argument - a tuple that presents the grouping key, and a pandas DataFrame represents the data. ## How was this patch tested? GroupbyApplyTests Author: Li Jin Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key. --- python/pyspark/serializers.py | 18 ++- python/pyspark/sql/functions.py | 25 ++++ python/pyspark/sql/tests.py | 121 ++++++++++++++++-- python/pyspark/sql/types.py | 45 ++++++- python/pyspark/sql/udf.py | 19 +-- python/pyspark/util.py | 16 +++ python/pyspark/worker.py | 49 +++++-- .../python/FlatMapGroupsInPandasExec.scala | 56 +++++++- 8 files changed, 294 insertions(+), 55 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 917e258d8a602..ebf549396f463 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -250,6 +250,15 @@ def __init__(self, timezone): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.types import from_arrow_type, \ + _check_series_convert_date, _check_series_localize_timestamps + + s = arrow_column.to_pandas() + s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _check_series_localize_timestamps(s, self._timezone) + return s + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or @@ -272,16 +281,11 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) - schema = from_arrow_schema(reader.schema) + for batch in reader: - pdf = batch.to_pandas() - pdf = _check_dataframe_convert_date(pdf, schema) - pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) - yield [c for _, c in pdf.iteritems()] + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b9c0c57262c5d..dc1341ac74d3d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 1.1094003924504583| +---+-------------------+ + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key will be passed as the first argument and the data will + be passed as the second argument. The grouping key will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key in the function. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a9fe0b425ad3e..480815d27333f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3903,7 +3903,7 @@ def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v): + def foo(k, v, w): return k @@ -4476,20 +4476,45 @@ def test_supported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col df = self.data.withColumn("arr", array(col("id"))) - foo_udf = pandas_udf( + # Different forms of group map pandas UDF, results of these are the same + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), + StructField('v1', DoubleType()), + StructField('v2', LongType())]) + + udf1 = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]), + output_schema, PandasUDFType.GROUPED_MAP ) - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) + udf2 = pandas_udf( + lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4648,6 +4673,80 @@ def test_timestamp_dst(self): result = df.groupby('time').apply(foo_udf).sort('time') self.assertPandasEqual(df.toPandas(), result.toPandas()) + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cd857402db8f7..1632862d3f1ba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_series_convert_date(series, data_type): + """ + Cast the series to datetime.date if it's a date type, otherwise returns the original series. + + :param series: pandas.Series + :param data_type: a Spark data type for the series + """ + if type(data_type) == DateType: + return series.dt.date + else: + return series + + def _check_dataframe_convert_date(pdf, schema): """ Correct date type value to use datetime.date. @@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema): :param schema: a Spark schema of the pandas.DataFrame """ for field in schema: - if type(field.dataType) == DateType: - pdf[field.name] = pdf[field.name].dt.date + pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) return pdf @@ -1725,6 +1737,29 @@ def _get_local_timezone(): return os.environ.get('TZ', 'dateutil/:') +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone): from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) + pdf[column] = _check_series_localize_timestamps(series, timezone) return pdf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b9b490874f4fb..ce804c18e9b14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,6 +17,8 @@ """ User-defined function related classes and functions """ +import sys +import inspect import functools from pyspark import SparkContext, since @@ -24,6 +26,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - import inspect - import sys from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() - if sys.version_info[0] < 3: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - argspec = inspect.getargspec(f) - else: - argspec = inspect.getfullargspec(f) + argspec = _get_argspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: @@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType): "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (1, 2): raise ValueError( "Invalid function: pandas_udfs with function type GROUPED_MAP " - "must take a single arg that is a pandas DataFrame." - ) + "must take either one argument (data) or two arguments (key, data).") # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ad4a0bc68ef41..6837b18b7d7a5 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import sys +import inspect from py4j.protocol import Py4JJavaError __all__ = [] @@ -45,6 +48,19 @@ def _exception_message(excp): return str(excp) +def _get_argspec(f): + """ + Get argspec of a function. Supports both Python 2 and Python 3. + """ + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + if sys.version_info[0] < 3: + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) + return argspec + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 89a3a92bc66d6..202cac350aafc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import _get_argspec from pyspark import shuffle pickleSer = PickleSerializer() @@ -91,10 +92,16 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type): - def wrapped(*series): + def wrapped(key_series, value_series): import pandas as pd + argspec = _get_argspec(f) + + if len(argspec.args) == 1: + result = f(pd.concat(value_series, axis=1)) + elif len(argspec.args) == 2: + key = tuple(s[0] for s in key_series) + result = f(key, pd.concat(value_series, axis=1)) - result = f(pd.concat(series, axis=1)) if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "pandas.DataFrame, but is {}".format(type(result))) @@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] - for i in range(num_udfs): + mapper_str = "" + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # Create function like this: + # lambda a: f([a[0]], [a[0], a[1]]) + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) - udfs['f%d' % i] = udf - args = ["a[%d]" % o for o in arg_offsets] - call_udf.append("f%d(%s)" % (i, ", ".join(args))) - # Create function like this: - # lambda a: (f0(a0), f1(a1, a2), f2(a3)) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) - mapper = eval(mapper_str, udfs) + udfs['f'] = udf + split_offset = arg_offsets[0] + 1 + arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] + arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + else: + # Create function like this: + # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. + for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c798fe5a92c54..513e174c7733e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... + // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) } else { val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) } } @@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 7013eea11cb32b1e0038dc751c485da5c94a484b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 8 Mar 2018 20:38:34 +0900 Subject: [PATCH 143/593] [SPARK-23522][PYTHON] always use sys.exit over builtin exit The exit() builtin is only for interactive use. applications should use sys.exit(). ## What changes were proposed in this pull request? All usage of the builtin `exit()` function is replaced by `sys.exit()`. ## How was this patch tested? I ran `python/run-tests`. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Benjamin Peterson Closes #20682 from benjaminp/sys-exit. --- dev/merge_spark_pr.py | 2 +- dev/run-tests.py | 2 +- examples/src/main/python/avro_inputformat.py | 2 +- examples/src/main/python/kmeans.py | 2 +- examples/src/main/python/logistic_regression.py | 2 +- examples/src/main/python/ml/dataframe_example.py | 2 +- examples/src/main/python/mllib/correlations.py | 2 +- examples/src/main/python/mllib/kmeans.py | 2 +- examples/src/main/python/mllib/logistic_regression.py | 2 +- examples/src/main/python/mllib/random_rdd_generation.py | 2 +- examples/src/main/python/mllib/sampled_rdds.py | 4 ++-- .../python/mllib/streaming_linear_regression_example.py | 2 +- examples/src/main/python/pagerank.py | 2 +- examples/src/main/python/parquet_inputformat.py | 2 +- examples/src/main/python/sort.py | 2 +- .../main/python/sql/streaming/structured_kafka_wordcount.py | 2 +- .../python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- .../src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/hdfs_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- .../main/python/streaming/recoverable_network_wordcount.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- examples/src/main/python/wordcount.py | 2 +- python/pyspark/accumulators.py | 2 +- python/pyspark/broadcast.py | 2 +- python/pyspark/conf.py | 2 +- python/pyspark/context.py | 2 +- python/pyspark/daemon.py | 2 +- python/pyspark/find_spark_home.py | 2 +- python/pyspark/heapq3.py | 3 ++- python/pyspark/ml/classification.py | 3 ++- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/image.py | 4 +++- python/pyspark/ml/linalg/__init__.py | 2 +- python/pyspark/ml/recommendation.py | 4 +++- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/stat.py | 4 +++- python/pyspark/ml/tuning.py | 6 ++++-- python/pyspark/mllib/classification.py | 3 ++- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/evaluation.py | 3 ++- python/pyspark/mllib/feature.py | 2 +- python/pyspark/mllib/fpm.py | 4 +++- python/pyspark/mllib/linalg/__init__.py | 2 +- python/pyspark/mllib/linalg/distributed.py | 2 +- python/pyspark/mllib/random.py | 3 ++- python/pyspark/mllib/recommendation.py | 3 ++- python/pyspark/mllib/regression.py | 6 ++++-- python/pyspark/mllib/stat/_statistics.py | 2 +- python/pyspark/mllib/tree.py | 3 ++- python/pyspark/mllib/util.py | 2 +- python/pyspark/profiler.py | 3 ++- python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 2 +- python/pyspark/shuffle.py | 3 ++- python/pyspark/sql/catalog.py | 3 ++- python/pyspark/sql/column.py | 2 +- python/pyspark/sql/conf.py | 4 +++- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 4 +++- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/session.py | 2 +- python/pyspark/sql/streaming.py | 2 +- python/pyspark/sql/types.py | 2 +- python/pyspark/sql/udf.py | 3 ++- python/pyspark/sql/window.py | 2 +- python/pyspark/streaming/util.py | 3 ++- python/pyspark/util.py | 4 +++- python/pyspark/worker.py | 6 +++--- python/setup.py | 6 +++--- 79 files changed, 120 insertions(+), 86 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 6b244d8184b2c..5ea205fbed4aa 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -510,7 +510,7 @@ def main(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) try: main() except: diff --git a/dev/run-tests.py b/dev/run-tests.py index fe75ef4411c8c..164c1e2200aa9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -621,7 +621,7 @@ def _test(): import doctest failure_count = doctest.testmod()[0] if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 6286ba6541fbd..a18722c687f8b 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -61,7 +61,7 @@ Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 92e0a3ae2ee60..a42d711fc505f 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -49,7 +49,7 @@ def closestPoint(p, centers): if len(sys.argv) != 4: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 01c938454b108..bcc4e0f4e8eae 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -48,7 +48,7 @@ def readPointBatch(iterator): if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index 109f901012c9c..d62cf2338a1fe 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -33,7 +33,7 @@ if __name__ == "__main__": if len(sys.argv) > 2: print("Usage: dataframe_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) elif len(sys.argv) == 2: input = sys.argv[1] else: diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 0e13546b88e67..089504fa7064b 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -31,7 +31,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: correlations ()", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: filepath = sys.argv[1] diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 002fc75799648..1bdb3e9b4a2af 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -36,7 +36,7 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index d4f1d34e2d8cf..87efe17375226 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -42,7 +42,7 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 729bae30b152c..9a429b5f8abdf 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: random_rdd_generation", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index b7033ab7daeb3..00e7cf4bbcdbf 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: sampled_rdds ", file=sys.stderr) - exit(-1) + sys.exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] else: @@ -43,7 +43,7 @@ numExamples = examples.count() if numExamples == 0: print("Error: Data file had no samples to load.", file=sys.stderr) - exit(1) + sys.exit(1) print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py index f600496867c11..714c9a0de7217 100644 --- a/examples/src/main/python/mllib/streaming_linear_regression_example.py +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -36,7 +36,7 @@ if len(sys.argv) != 3: print("Usage: streaming_linear_regression_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0d6c253d397a0..2c19e8700ab16 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -47,7 +47,7 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: pagerank ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("WARN: This is a naive implementation of PageRank and is given as an example!\n" + "Please refer to PageRank implementation provided by graphx", diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index a3f86cf8999cf..83041f0040a0c 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -45,7 +45,7 @@ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 81898cf6d5ce6..d3cd985d197e3 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -25,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py index 9e8a552b3b10b..921067891352a 100644 --- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -49,7 +49,7 @@ print(""" Usage: structured_kafka_wordcount.py """, file=sys.stderr) - exit(-1) + sys.exit(-1) bootstrapServers = sys.argv[1] subscribeType = sys.argv[2] diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index c3284c1d01017..9ac392164735b 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -38,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: structured_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index db672551504b5..c4e3bbf44cd5a 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -53,7 +53,7 @@ msg = ("Usage: structured_network_wordcount_windowed.py " " []") print(msg, file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 425df309011a0..c5c186c11f79a 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") ssc = StreamingContext(sc, 2) diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 5d6e6dc36d6f9..c8ea92b61ca6e 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: flume_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingFlumeWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f815dd26823d1..f9a5c43a8eaa9 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: hdfs_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 704f6602e2297..e9ee08b9fd228 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 9010fafb425e6..f3099d2517cd5 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index d51a380a5d5f9..2b5434c0c845a 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -47,7 +47,7 @@ def print_happiest_words(rdd): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") ssc = StreamingContext(sc, 5) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 52b2639cdf55c..60167dc772544 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -101,7 +101,7 @@ def filterFunc(wordCount): if len(sys.argv) != 5: print("Usage: recoverable_network_wordcount.py " " ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, lambda: createContext(host, int(port), output)) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 7f12281c0e3fe..ab3cfc067994d 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -48,7 +48,7 @@ def getSparkSessionInstance(sparkConf): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: sql_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port = sys.argv[1:] sc = SparkContext(appName="PythonSqlNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index d7bb61e729f18..d5d1eba6c5969 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: stateful_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 3d5e44d5b2df1..a05e24ff3ff95 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -26,7 +26,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 7def676b89a24..f730d290273fe 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -265,4 +265,4 @@ def _start_update_server(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 02fc515fb824a..b3dfc99962a35 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -162,4 +162,4 @@ def clear(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 491b3a81972bc..ab429d9ab10de 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -217,7 +217,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 24905f1c97b21..7c664966ed74e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1035,7 +1035,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7f06d4288c872..7bed5216eabf3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -89,7 +89,7 @@ def shutdown(code): signal.signal(SIGTERM, SIG_DFL) # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) - exit(code) + sys.exit(code) def handle_sigterm(*args): shutdown(1) diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 212a618b767ab..9cf0e8c8d2fe9 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -68,7 +68,7 @@ def is_spark_home(path): return next(path for path in paths if is_spark_home(path)) except StopIteration: print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) - exit(-1) + sys.exit(-1) if __name__ == "__main__": print(_find_spark_home()) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index b27e91a4cc251..6af084adcf373 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -884,6 +884,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": import doctest + import sys (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d3..fbbe3d0307c81 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,6 +16,7 @@ # import operator +import sys from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only @@ -2043,4 +2044,4 @@ def _to_java(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6448b76a0da88..b3d5fb17f6b81 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper @@ -1181,4 +1183,4 @@ def getKeepLastCheckpoint(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 695d8ab27cc96..8eaf07645a37f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys from abc import abstractmethod, ABCMeta from pyspark import since, keyword_only @@ -446,4 +447,4 @@ def getDistanceMeasure(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b07e6a05481..f2e357f0bede5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3717,4 +3717,4 @@ def setSize(self, value): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 45c936645f2a8..96d702f844839 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -24,6 +24,8 @@ :members: """ +import sys + import numpy as np from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string @@ -251,7 +253,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index ad1b487676fa7..6a611a2b5b59d 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1158,7 +1158,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e8bcbe4cd34cb..a8eae9bd268d3 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -480,4 +482,4 @@ def recommendForItemSubset(self, dataset, numUsers): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f0812bd1d4a39..de0a0fa9f3bf8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since, keyword_only @@ -1812,4 +1813,4 @@ def __repr__(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 079b0833e1c6d..0eeb5e528434a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java from pyspark.ml.wrapper import _jvm @@ -151,4 +153,4 @@ def corr(dataset, column, method="pearson"): failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 6c0cad6cbaaa1..545e24ca05aa5 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,9 +15,11 @@ # limitations under the License. # import itertools -import numpy as np +import sys from multiprocessing.pool import ThreadPool +import numpy as np + from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java @@ -727,4 +729,4 @@ def _to_java(self): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cce703d432b5a..bb281981fd56b 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -16,6 +16,7 @@ # from math import exp +import sys import warnings import numpy @@ -761,7 +762,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index bb687a7da6ffd..0cbabab13a896 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1048,7 +1048,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 2cd1da3fbf9aa..36cb03369b8c0 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since @@ -542,7 +543,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index e5231dc3a27a8..40ecd2e0ff4be 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -819,7 +819,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": sys.path.pop(0) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f58ea5dfb0874..de18dad1f675d 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + import numpy from numpy import array from collections import namedtuple @@ -197,7 +199,7 @@ def _test(): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 7b24b3c74a9fa..60d96d8d5ceb8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1370,7 +1370,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 4cb802514be52..bba88542167ad 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1377,7 +1377,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 61213ddf62e8b..a8833cb446923 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -19,6 +19,7 @@ Python package for random data generation. """ +import sys from functools import wraps from pyspark import since @@ -421,7 +422,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 81182881352bb..3d4eae85132bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,6 +16,7 @@ # import array +import sys from collections import namedtuple from pyspark import SparkContext, since @@ -326,7 +327,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index ea107d400621d..6be45f51862c9 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -15,9 +15,11 @@ # limitations under the License. # +import sys +import warnings + import numpy as np from numpy import array -import warnings from pyspark import RDD, since from pyspark.streaming.dstream import DStream @@ -837,7 +839,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 49b26446dbc32..3c75b132ecad2 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -313,7 +313,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 619fa16d463f5..b05734ce489d9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -17,6 +17,7 @@ from __future__ import absolute_import +import sys import random from pyspark import SparkContext, RDD, since @@ -654,7 +655,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 97755807ef262..fc7809387b13a 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -521,7 +521,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 44d17bd629473..3c7656ab5758c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -19,6 +19,7 @@ import pstats import os import atexit +import sys from pyspark.accumulators import AccumulatorParam @@ -173,4 +174,4 @@ def stats(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 93b8974a7e64a..4b44f76747264 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2498,7 +2498,7 @@ def _test(): globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ebf549396f463..15753f77bd903 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -715,4 +715,4 @@ def write_with_length(obj, stream): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e974cda9fc3e1..02c773302e9da 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -23,6 +23,7 @@ import itertools import operator import random +import sys import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ @@ -810,4 +811,4 @@ def load_partition(j): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 6aef0f22340be..b0d8357f4feec 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from collections import namedtuple @@ -306,7 +307,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 43b38a2cd477c..e05a7b33c11a7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -660,7 +660,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 792c420ca6386..d929834aeeaa5 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix @@ -80,7 +82,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index cc1cd1a5842d9..6cb90399dd616 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -543,7 +543,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8f90a367e8bf8..3fc194d8ec1d1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2231,7 +2231,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dc1341ac74d3d..dff590983b4d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2404,7 +2404,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ab646535c864c..35cac406e0965 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal @@ -299,7 +301,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9d05ac7cb39be..803f561ece67b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -970,7 +970,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) sc.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 215bb3e5c5173..e82a9750a0014 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -830,7 +830,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index cc622decfd682..e8966c20a8f42 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -930,7 +930,7 @@ def _test(): globs['spark'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1632862d3f1ba..826aab97e58db 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1890,7 +1890,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index ce804c18e9b14..24dd06c26089c 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -20,6 +20,7 @@ import sys import inspect import functools +import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix @@ -397,7 +398,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index bb841a9b9ff7c..e667fba099fb9 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -264,7 +264,7 @@ def _test(): SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..df184471993ff 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -18,6 +18,7 @@ import time from datetime import datetime import traceback +import sys from pyspark import SparkContext, RDD @@ -147,4 +148,4 @@ def rddToFileName(prefix, suffix, timestamp): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 6837b18b7d7a5..ed1bdd0e4be83 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,6 +22,8 @@ __all__ = [] +import sys + def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -65,4 +67,4 @@ def _get_argspec(f): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 202cac350aafc..a1a4336b1e8de 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -205,7 +205,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - exit(-1) + sys.exit(-1) version = utf8_deserializer.loads(infile) if version != "%d.%d" % sys.version_info[:2]: @@ -279,7 +279,7 @@ def process(): # Write the error to stderr if it happened while serializing print("PySpark worker failed with exception:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - exit(-1) + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) @@ -297,7 +297,7 @@ def process(): else: # write a different value to tell JVM to not reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - exit(-1) + sys.exit(-1) if __name__ == '__main__': diff --git a/python/setup.py b/python/setup.py index 6a98401941d8d..794ceceae3008 100644 --- a/python/setup.py +++ b/python/setup.py @@ -26,7 +26,7 @@ if sys.version_info < (2, 7): print("Python versions prior to 2.7 are not supported for pip installed PySpark.", file=sys.stderr) - exit(-1) + sys.exit(-1) try: exec(open('pyspark/version.py').read()) @@ -98,7 +98,7 @@ def _supports_symlinks(): except: print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), file=sys.stderr) - exit(-1) + sys.exit(-1) # If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and # ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. @@ -140,7 +140,7 @@ def _supports_symlinks(): if not os.path.isdir(SCRIPTS_TARGET): print(incorrect_invocation_message, file=sys.stderr) - exit(-1) + sys.exit(-1) # Scripts directive requires a list of each script path and does not take wild cards. script_names = os.listdir(SCRIPTS_TARGET) From 92e7ecbbbd6817378abdbd56541a9c13dcea8659 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 14:18:14 +0100 Subject: [PATCH 144/593] [SPARK-23592][SQL] Add interpreted execution to DecodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to DecodeUsingSerializer. ## How was this patch tested? added UT Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Marco Gaido Closes #20760 from mgaido91/SPARK-23592. --- .../catalyst/expressions/objects/objects.scala | 5 +++++ .../expressions/ObjectExpressionsSuite.scala | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7bbc3c732e782..adf9ddf327c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1242,6 +1242,11 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression with NonSQLExpression with SerializerSupport { + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 346b13277c709..ffeec2a38c532 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row @@ -123,4 +125,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { + val cls = classOf[java.lang.Integer] + val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val input = serializer.newInstance().serialize(new Integer(1)).array() + val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) + checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 3be4adf6485ca19cdc5db23394c3f5a660d7dc6f Mon Sep 17 00:00:00 2001 From: lucio <576632108@qq.com> Date: Thu, 8 Mar 2018 08:03:24 -0600 Subject: [PATCH 145/593] [SPARK-22751][ML] Improve ML RandomForest shuffle performance ## What changes were proposed in this pull request? As I mentioned in [SPARK-22751](https://issues.apache.org/jira/browse/SPARK-22751?jql=project%20%3D%20SPARK%20AND%20component%20%3D%20ML%20AND%20text%20~%20randomforest), there is a shuffle performance problem in ML Randomforest when train a RF in high dimensional data. The reason is that, in _org.apache.spark.tree.impl.RandomForest_, the function _findSplitsBySorting_ will actually flatmap a sparse vector into a dense vector, then in groupByKey there will be a huge shuffle write size. To avoid this, we can add a filter in flatmap, to filter out zero value. And in function _findSplitsForContinuousFeature_, we can infer the number of zero value by _metadata_. In addition, if a feature only contains zero value, _continuousSplits_ will not has the key of feature id. So I add a check when using _continuousSplits_. ## How was this patch tested? Ran model locally using spark-submit. Author: lucio <576632108@qq.com> Closes #20472 from lucio-yz/master. --- .../spark/ml/tree/impl/RandomForest.scala | 52 ++++++++++++++----- .../ml/tree/impl/RandomForestSuite.scala | 23 ++++---- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 8e514f11e78ea..16f32d76b9984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -892,13 +892,7 @@ private[spark] object RandomForest extends Logging { // Sample the input only if there are continuous features. val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) val sampledInput = if (continuousFeatures.nonEmpty) { - // Calculate the number of samples for approximate quantile calculation. - val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples - } else { - 1.0 - } + val fraction = samplesFractionForFindSplits(metadata) logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { @@ -920,8 +914,9 @@ private[spark] object RandomForest extends Logging { val numPartitions = math.min(continuousFeatures.length, input.partitions.length) input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) - .groupByKey(numPartitions) + .flatMap { point => + continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0) + }.groupByKey(numPartitions) .map { case (idx, samples) => val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) @@ -933,7 +928,8 @@ private[spark] object RandomForest extends Logging { val numFeatures = metadata.numFeatures val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { case i if metadata.isContinuous(i) => - val split = continuousSplits(i) + // some features may contain only zero, so continuousSplits will not have a record + val split = continuousSplits.getOrElse(i, Array.empty[Split]) metadata.setNumSplits(i, split.length) split @@ -1003,11 +999,22 @@ private[spark] object RandomForest extends Logging { } else { val numSplits = metadata.numSplits(featureIndex) - // get count for each distinct value - val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { - case ((m, cnt), x) => - (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + // get count for each distinct value except zero value + val partNumSamples = featureSamples.size + val partValueCountMap = scala.collection.mutable.Map[Double, Int]() + featureSamples.foreach { x => + partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1 } + + // Calculate the expected number of samples for finding splits + val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt + // add expected zero value count and get complete statistics + val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) { + partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples)) + } else { + partValueCountMap.toMap + } + // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -1149,4 +1156,21 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + + /** + * Calculate the subsample fraction for finding splits + * + * @param metadata decision tree metadata + * @return subsample fraction + */ + private def samplesFractionForFindSplits( + metadata: DecisionTreeMetadata): Double = { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 5f0d26eb5c058..743dacf146fe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -93,12 +93,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array.fill(200000)(math.random) + val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) assert(fakeMetadata.numSplits(0) === 5) @@ -109,7 +109,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -117,7 +117,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits <= numSplits { - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -125,7 +125,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { - val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -135,7 +135,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -150,7 +150,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -164,12 +164,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + .map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) @@ -177,12 +178,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits for constant feature { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits === Array.empty[Double]) From ea480990e726aed59750f1cea8d40adba56d991a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 11:09:15 -0800 Subject: [PATCH 146/593] [SPARK-23628][SQL] calculateParamLength should not return 1 + num of epressions ## What changes were proposed in this pull request? There was a bug in `calculateParamLength` which caused it to return always 1 + the number of expressions. This could lead to Exceptions especially with expressions of type long. ## How was this patch tested? added UT + fixed previous UT Author: Marco Gaido Closes #20772 from mgaido91/SPARK-23628. --- .../expressions/codegen/CodeGenerator.scala | 51 ++++++++++--------- .../expressions/CodeGenerationSuite.scala | 6 +++ .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../execution/WholeStageCodegenSuite.scala | 16 +++--- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 793824b0b0a2f..fe5e63ec0a2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1063,31 +1063,6 @@ class CodegenContext { "" } } - - /** - * Returns the length of parameters for a Java method descriptor. `this` contributes one unit - * and a parameter of type long or double contributes two units. Besides, for nullable parameter, - * we also need to pass a boolean parameter for the null status. - */ - def calculateParamLength(params: Seq[Expression]): Int = { - def paramLengthForExpr(input: Expression): Int = { - // For a nullable expression, we need to pass in an extra boolean parameter. - (if (input.nullable) 1 else 0) + javaType(input.dataType) match { - case JAVA_LONG | JAVA_DOUBLE => 2 - case _ => 1 - } - } - // Initial value is 1 for `this`. - 1 + params.map(paramLengthForExpr(_)).sum - } - - /** - * In Java, a method descriptor is valid only if it represents method parameters with a total - * length less than a pre-defined constant. - */ - def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH - } } /** @@ -1538,4 +1513,30 @@ object CodeGenerator extends Logging { def defaultValue(dt: DataType, typedNull: Boolean = false): String = defaultValue(javaType(dt), typedNull) + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + val javaParamLength = javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaParamLength + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e48c7b8df9da..64c13e8972036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.addImmutableStateIfNotExists("String", mutableState2) assert(ctx.inlinedMutableStates.length == 2) } + + test("SPARK-23628: calculateParamLength should compute properly the param length") { + assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101) + assert(CodeGenerator.calculateParamLength( + Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index f89e3fb0e536f..6ddaacfee1a40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -174,8 +174,9 @@ trait CodegenSupport extends SparkPlan { // declaration. val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator val requireAllOutput = output.forall(parent.usedInputs.contains(_)) - val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) - val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput + && CodeGenerator.isValidParamLength(paramLength)) { constructDoConsumeFunction(ctx, inputVars, row) } else { parent.doConsume(ctx, inputVars, rowVar) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ef16292a8e75c..0fb9dd2017a09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Skip splitting consume function when parameter number exceeds JVM limit") { - import testImplicits._ - - Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + // since every field is nullable we have 2 params for each input column (one for the value + // and one for the isNull variable) + Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) => withTempPath { dir => val path = dir.getCanonicalPath - spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*) .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", @@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val df = spark.read.parquet(path).selectExpr(projection: _*) val plan = df.queryExecution.executedPlan - val wholeStageCodeGenExec = plan.find(p => p match { - case wp: WholeStageCodegenExec => true + val wholeStageCodeGenExec = plan.find { + case _: WholeStageCodegenExec => true case _ => false - }) + } assert(wholeStageCodeGenExec.isDefined) val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 assert(code.body.contains("project_doConsume") == hasSplit) From e7bbca88964d95593fa15eb94643ba519801e352 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 22:02:28 +0100 Subject: [PATCH 147/593] [SPARK-23602][SQL] PrintToStderr prints value also in interpreted mode ## What changes were proposed in this pull request? `PrintToStderr` was doing what is it supposed to only when code generation is enabled. The PR adds the same behavior in interpreted mode too. ## How was this patch tested? added UT Author: Marco Gaido Closes #20773 from mgaido91/SPARK-23602. --- .../spark/sql/catalyst/expressions/misc.scala | 7 +++++- .../expressions/MiscExpressionsSuite.scala | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4b9006ab5b423..38e4fe44b15ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,12 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType - protected override def nullSafeEval(input: Any): Any = input + protected override def nullSafeEval(input: Any): Any = { + // scalastyle:off println + System.err.println(outputPrefix + input) + // scalastyle:on println + input + } private val outputPrefix = s"Result of ${child.simpleString} is " diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index a21c139fe71d0..c3d08bf68c7bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.PrintStream + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -43,4 +45,27 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Uuid()), 36) assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } + + test("PrintToStderr") { + val inputExpr = Literal(1) + val systemErr = System.err + + val (outputEval, outputCodegen) = try { + val errorStream = new java.io.ByteArrayOutputStream() + System.setErr(new PrintStream(errorStream)) + // check without codegen + checkEvaluationWithoutCodegen(PrintToStderr(inputExpr), 1) + val outputEval = errorStream.toString + errorStream.reset() + // check with codegen + checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1) + val outputCodegen = errorStream.toString + (outputEval, outputCodegen) + } finally { + System.setErr(systemErr) + } + + assert(outputCodegen.contains(s"Result of $inputExpr is 1")) + assert(outputEval.contains(s"Result of $inputExpr is 1")) + } } From d90e77bd0ec19f8ba9198a24ec2ab3db7708eca8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 8 Mar 2018 14:58:40 -0800 Subject: [PATCH 148/593] [SPARK-23271][SQL] Parquet output contains only _SUCCESS file after writing an empty dataframe ## What changes were proposed in this pull request? Below are the two cases. ``` SQL case 1 scala> List.empty[String].toDF().rdd.partitions.length res18: Int = 1 ``` When we write the above data frame as parquet, we create a parquet file containing just the schema of the data frame. Case 2 ``` SQL scala> val anySchema = StructType(StructField("anyName", StringType, nullable = false) :: Nil) anySchema: org.apache.spark.sql.types.StructType = StructType(StructField(anyName,StringType,false)) scala> spark.read.schema(anySchema).csv("/tmp/empty_folder").rdd.partitions.length res22: Int = 0 ``` For the 2nd case, since number of partitions = 0, we don't call the write task (the task has logic to create the empty metadata only parquet file) The fix is to create a dummy single partition RDD and set up the write task based on it to ensure the metadata-only file. ## How was this patch tested? A new test is added to DataframeReaderWriterSuite. Author: Dilip Biswal Closes #20525 from dilipbiswal/spark-23271. --- docs/sql-programming-guide.md | 1 + .../datasources/FileFormatWriter.scala | 15 ++++++++++++--- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 1 - 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 451b814ab6c53..d2132d2ae7441 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,6 +1805,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1d80a69bc5a1d..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -190,9 +190,18 @@ object FileFormatWriter extends Logging { global = false, child = plan).execute() } - val ret = new Array[WriteTaskResult](rdd.partitions.length) + + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + + val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( - rdd, + rddWithNonEmptyPartitions, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -202,7 +211,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 73e3df3b6202e..bd3071bcf9010 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -89,6 +90,23 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + Seq("orc", "parquet").foreach { format => + test(s"SPARK-23271 empty RDD when saved should write a metadata only file - $format") { + withTempPath { outputPath => + val df = spark.emptyDataFrame.select(lit(1).as("i")) + df.write.format(format).save(outputPath.toString) + val partFiles = outputPath.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + // Now read the file. + val df1 = spark.read.format(format).load(outputPath.toString) + checkAnswer(df1, Seq.empty[Row]) + assert(df1.schema.equals(df.schema.asNullable)) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8c9bb7d56a35f..a707a88dfa670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -301,7 +301,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be intercept[AnalysisException] { spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path) } - spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) } } From 2c3673680e16f88f1d1cd73a3f7445ded5b3daa8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 9 Mar 2018 10:36:38 -0800 Subject: [PATCH 149/593] [SPARK-23630][YARN] Allow user's hadoop conf customizations to take effect. This change restores functionality that was inadvertently removed as part of the fix for SPARK-22372. Also modified an existing unit test to make sure the feature works as intended. Author: Marcelo Vanzin Closes #20776 from vanzin/SPARK-23630. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 11 +++++- .../org/apache/spark/deploy/yarn/Client.scala | 14 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 34 ++++++++++++++----- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e14f9845e6db6..177295fb7af0f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -111,7 +111,9 @@ class SparkHadoopUtil extends Logging { * subsystems. */ def newConfiguration(conf: SparkConf): Configuration = { - SparkHadoopUtil.newConfiguration(conf) + val hadoopConf = SparkHadoopUtil.newConfiguration(conf) + hadoopConf.addResource(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE) + hadoopConf } /** @@ -435,6 +437,13 @@ object SparkHadoopUtil { */ private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + /** + * Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the + * cluster's Hadoop config. It is up to the Spark code launching the application to create + * this file if it's desired. If the file doesn't exist, it will just be ignored. + */ + private[spark] val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" + def get: SparkHadoopUtil = instance /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8cd3cd9746a3a..28087dee831d1 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -696,7 +696,13 @@ private[spark] class Client( } } - Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + // SPARK-23630: during testing, Spark scripts filter out hadoop conf dirs so that user's + // environments do not interfere with tests. This allows a special env variable during + // tests so that custom conf dirs can be used by unit tests. + val confDirs = Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR") ++ + (if (Utils.isTesting) Seq("SPARK_TEST_HADOOP_CONF_DIR") else Nil) + + confDirs.foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { @@ -753,7 +759,7 @@ private[spark] class Client( // Save the YARN configuration into a separate file that will be overlayed on top of the // cluster's Hadoop conf. - confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE)) + confStream.putNextEntry(new ZipEntry(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE)) hadoopConf.writeXml(confStream) confStream.closeEntry() @@ -1220,10 +1226,6 @@ private object Client extends Logging { // Name of the file in the conf archive containing Spark configuration. val SPARK_CONF_FILE = "__spark_conf__.properties" - // Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the - // cluster's Hadoop config. - val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" - // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 5003326b440bf..33d400a5b1b2e 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -114,12 +114,25 @@ class YarnClusterSuite extends BaseYarnClusterSuite { )) } - test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") { + test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)") { + // Create a custom hadoop config file, to make sure it's contents are propagated to the driver. + val customConf = Utils.createTempDir() + val coreSite = """ + | + | + | spark.test.key + | testvalue + | + | + |""".stripMargin + Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8) + val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(false, mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), - appArgs = Seq("key=value", result.getAbsolutePath()), - extraConf = Map("spark.hadoop.key" -> "value")) + appArgs = Seq("key=value", "spark.test.key=testvalue", result.getAbsolutePath()), + extraConf = Map("spark.hadoop.key" -> "value"), + extraEnv = Map("SPARK_TEST_HADOOP_CONF_DIR" -> customConf.getAbsolutePath())) checkResult(finalState, result) } @@ -319,13 +332,13 @@ private object YarnClusterDriverWithFailure extends Logging with Matchers { private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers { def main(args: Array[String]): Unit = { - if (args.length != 2) { + if (args.length < 2) { // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file] + |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value]+ [result file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -335,11 +348,16 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn test using SparkHadoopUtil's conf")) - val kv = args(0).split("=") - val status = new File(args(1)) + val kvs = args.take(args.length - 1).map { kv => + val parsed = kv.split("=") + (parsed(0), parsed(1)) + } + val status = new File(args.last) var result = "failure" try { - SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1)) + kvs.foreach { case (k, v) => + SparkHadoopUtil.get.conf.get(k) should be (v) + } result = "success" } finally { Files.write(result, status, StandardCharsets.UTF_8) From 2ca9bb083c515511d2bfee271fc3e0269aceb9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20=C5=9Awitakowski?= Date: Fri, 9 Mar 2018 14:29:31 -0800 Subject: [PATCH 150/593] [SPARK-23173][SQL] Avoid creating corrupt parquet files when loading data from JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The from_json() function accepts an additional parameter, where the user might specify the schema. The issue is that the specified schema might not be compatible with data. In particular, the JSON data might be missing data for fields declared as non-nullable in the schema. The from_json() function does not verify the data against such errors. When data with missing fields is sent to the parquet encoder, there is no verification either. The end results is a corrupt parquet file. To avoid corruptions, make sure that all fields in the user-specified schema are set to be nullable. Since this changes the behavior of a public function, we need to include it in release notes. The behavior can be reverted by setting `spark.sql.fromJsonForceNullableSchema=false` ## How was this patch tested? Added two new tests. Author: Michał Świtakowski Closes #20694 from mswit-databricks/SPARK-23173. --- .../expressions/jsonExpressions.scala | 22 +++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../expressions/JsonExpressionsSuite.scala | 30 ++++++++++++++++++- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++++++++ 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b4fed597447..fdd672c416a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -515,10 +516,15 @@ case class JsonToStructs( child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - override def nullable: Boolean = true - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, None) + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + + // The JSON input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder + // can generate incorrect files if values are missing in columns declared as non-nullable. + val nullableSchema = if (forceNullableSchema) schema.asNullable else schema + + override def nullable: Boolean = true // Used in `FunctionRegistry` def this(child: Expression, schema: Expression) = @@ -535,22 +541,22 @@ case class JsonToStructs( child = child, timeZoneId = None) - override def checkInputDataTypes(): TypeCheckResult = schema match { + override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") } @transient - lazy val rowSchema = schema match { + lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st } // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = schema match { + lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => @@ -563,7 +569,7 @@ case class JsonToStructs( rowSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) - override def dataType: DataType = schema + override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3f96112659c11..11864bd1b1847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -493,6 +493,14 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to curruptions.") + .booleanConf + .createWithDefault(true) + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index a0bbe02f92354..7812319756eae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,11 +22,13 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { val json = """ |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}], @@ -680,4 +682,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } } + + test("from_json missing fields") { + for (forceJsonNullableSchema <- Seq(false, true)) { + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + checkEvaluation( + JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), + output + ) + val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + .dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3af80930ec807..0b3e8ca060d87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -780,6 +781,24 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(option.compressionCodecClassName == "UNCOMPRESSED") } } + + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { + withTempPath { file => + val jsonData = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input") + .write.parquet(file.getAbsolutePath) + checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From 10b0657b035641ce735055bba2c8459e71bc2400 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 9 Mar 2018 15:41:19 -0800 Subject: [PATCH 151/593] [SPARK-23624][SQL] Revise doc of method pushFilters in Datasource V2 ## What changes were proposed in this pull request? Revise doc of method pushFilters in SupportsPushDownFilters/SupportsPushDownCatalystFilters In `FileSourceStrategy`, except `partitionKeyFilters`(the references of which is subset of partition keys), all filters needs to be evaluated after scanning. Otherwise, Spark will get wrong result from data sources like Orc/Parquet. This PR is to improve the doc. Author: Wang Gengliang Closes #20769 from gengliangwang/revise_pushdown_doc. --- .../sql/sources/v2/reader/SupportsPushDownCatalystFilters.java | 2 +- .../spark/sql/sources/v2/reader/SupportsPushDownFilters.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 98224102374aa..290d614805ac7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -34,7 +34,7 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Expression[] pushCatalystFilters(Expression[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index f35c711b0387a..1cff024232a44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -32,7 +32,7 @@ public interface SupportsPushDownFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Filter[] pushFilters(Filter[] filters); From 1a54f48b6744032b16543594651ee6d5e3ad4bda Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 9 Mar 2018 15:54:55 -0800 Subject: [PATCH 152/593] [SPARK-23510][SQL][FOLLOW-UP] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? In the PR https://github.com/apache/spark/pull/20671, I forgot to update the doc about this new support. ## How was this patch tested? N/A Author: gatorsmile Closes #20789 from gatorsmile/docUpdate. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2132d2ae7441..0e092e0e37ccf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2229,7 +2229,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From b6f837c9d3cb0f76f0a52df37e34aea8944f6867 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Sat, 10 Mar 2018 19:48:29 +0900 Subject: [PATCH 153/593] [PYTHON] Changes input variable to not conflict with built-in function Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Changes variable name conflict: [input is a built-in python function](https://stackoverflow.com/questions/20670732/is-input-a-keyword-in-python). ## How was this patch tested? I runned the example and it works fine. Author: DylanGuedes Closes #20775 from DylanGuedes/input_variable. --- examples/src/main/python/ml/dataframe_example.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index d62cf2338a1fe..cabc3de68f2f4 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -17,7 +17,7 @@ """ An example of how to use DataFrame for ML. Run with:: - bin/spark-submit examples/src/main/python/ml/dataframe_example.py + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -35,18 +35,18 @@ print("Usage: dataframe_example.py ", file=sys.stderr) sys.exit(-1) elif len(sys.argv) == 2: - input = sys.argv[1] + input_path = sys.argv[1] else: - input = "data/mllib/sample_libsvm_data.txt" + input_path = "data/mllib/sample_libsvm_data.txt" spark = SparkSession \ .builder \ .appName("DataFrameExample") \ .getOrCreate() - # Load input data - print("Loading LIBSVM file with UDT from " + input + ".") - df = spark.read.format("libsvm").load(input).cache() + # Load an input file + print("Loading LIBSVM file with UDT from " + input_path + ".") + df = spark.read.format("libsvm").load(input_path).cache() print("Schema from LIBSVM:") df.printSchema() print("Loaded training data as a DataFrame with " + From b304e07e0671faf96530f9d8f49c55a83b07fa15 Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Mon, 12 Mar 2018 22:13:28 +0900 Subject: [PATCH 154/593] [SPARK-23462][SQL] improve missing field error message in `StructType` ## What changes were proposed in this pull request? The error message ```s"""Field "$name" does not exist."""``` is thrown when looking up an unknown field in StructType. In the error message, we should also contain the information about which columns/fields exist in this struct. ## How was this patch tested? Added new unit tests. Note: I created a new `StructTypeSuite.scala` as I couldn't find an existing suite that's suitable to place these tests. I may be missing something so feel free to propose new locations. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xiayun Sun Closes #20649 from xysun/SPARK-23462. --- .../apache/spark/sql/types/StructType.scala | 11 +++-- .../spark/sql/types/StructTypeSuite.scala | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d5011c3cb87e9..362676b252126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -271,7 +271,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def apply(name: String): StructField = { nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } /** @@ -284,7 +286,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") + s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) } // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) @@ -297,7 +300,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala new file mode 100644 index 0000000000000..c6ca8bb005429 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class StructTypeSuite extends SparkFunSuite { + + val s = StructType.fromDDL("a INT, b STRING") + + test("lookup a single missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s("c")).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup a set of missing fields should output existing fields") { + val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup fieldIndex for missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage + assert(e.contains("Available fields: a, b")) + } +} From d5b41aea62201cd5b1baad2f68f5fc7eb99c62c5 Mon Sep 17 00:00:00 2001 From: Jooseong Kim Date: Mon, 12 Mar 2018 11:31:34 -0700 Subject: [PATCH 155/593] [SPARK-23618][K8S][BUILD] Initialize BUILD_ARGS in docker-image-tool.sh ## What changes were proposed in this pull request? This change initializes BUILD_ARGS to an empty array when $SPARK_HOME/RELEASE exists. In function build, "local BUILD_ARGS" effectively creates an array of one element where the first and only element is an empty string, so "${BUILD_ARGS[]}" expands to "" and passes an extra argument to docker. Setting BUILD_ARGS to an empty array makes "${BUILD_ARGS[]}" expand to nothing. ## How was this patch tested? Manually tested. $ cat RELEASE Spark 2.3.0 (git revision a0d7949896) built for Hadoop 2.7.3 Build flags: -Phadoop-2.7 -Phive -Phive-thriftserver -Pkafka-0-8 -Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr -DzincPort=3036 $ ./bin/docker-image-tool.sh -m t testing build Sending build context to Docker daemon 256.4MB ... vanzin Author: Jooseong Kim Closes #20791 from jooseong/SPARK-23618. --- bin/docker-image-tool.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 071406336d1b1..0d0f564bb8b9b 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -57,6 +57,7 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" + BUILD_ARGS=() fi if [ ! -d "$IMG_PATH" ]; then From 567bd31e0ae8b632357baa93e1469b666fb06f3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 12 Mar 2018 14:53:15 -0500 Subject: [PATCH 156/593] [SPARK-23412][ML] Add cosine distance to BisectingKMeans ## What changes were proposed in this pull request? The PR adds the option to specify a distance measure in BisectingKMeans. Moreover, it introduces the ability to use the cosine distance measure in it. ## How was this patch tested? added UTs + existing UTs Author: Marco Gaido Closes #20600 from mgaido91/SPARK-23412. --- .../spark/ml/clustering/BisectingKMeans.scala | 16 +- .../apache/spark/ml/clustering/KMeans.scala | 11 +- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 19 ++ .../mllib/clustering/BisectingKMeans.scala | 139 ++++---- .../clustering/BisectingKMeansModel.scala | 115 +++++-- .../mllib/clustering/DistanceMeasure.scala | 303 ++++++++++++++++++ .../spark/mllib/clustering/KMeans.scala | 196 +---------- .../ml/clustering/BisectingKMeansSuite.scala | 44 ++- project/MimaExcludes.scala | 6 + 10 files changed, 557 insertions(+), 298 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 4c20e6563bad1..f7c422dc0faea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, + BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType} /** * Common params for BisectingKMeans and BisectingKMeansModel */ -private[clustering] trait BisectingKMeansParams extends Params - with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter + with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure { /** * The desired number of leaf clusters. Must be > 1. Default: 4. @@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) @@ -263,6 +272,7 @@ class BisectingKMeans @Since("2.0.0") ( .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index c8145de564cbe..987a4285ebad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure { /** * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than @@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) - @Since("2.4.0") - final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - (value: String) => MLlibKMeans.validateDistanceMeasure(value)) - - /** @group expertGetParam */ - @Since("2.4.0") - def getDistanceMeasure: String = $(distanceMeasure) - /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 6ad44af9ef7eb..b9c3170cc3c28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen { "after fitting. If set to true, then all sub-models will be available. Warning: For " + "large models, collecting all sub-models can cause OOMs on the Spark driver", Some("false"), isExpertParam = true), - ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false) + ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false), + ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), + isValid = "(value: String) => " + + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index be8b2f273164b..282ea6ebcbf7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -504,4 +504,23 @@ trait HasLoss extends Params { /** @group getParam */ final def getLoss: String = $(loss) } + +/** + * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasDistanceMeasure extends Params { + + /** + * Param for The distance measure. Supported options: 'euclidean' and 'cosine'. + * @group param + */ + final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)) + + setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN) + + /** @group getParam */ + final def getDistanceMeasure: String = $(distanceMeasure) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 2221f4c0edc17..98af487306dcc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -57,7 +57,8 @@ class BisectingKMeans private ( private var k: Int, private var maxIterations: Int, private var minDivisibleClusterSize: Double, - private var seed: Long) extends Logging { + private var seed: Long, + private var distanceMeasure: String) extends Logging { import BisectingKMeans._ @@ -65,7 +66,7 @@ class BisectingKMeans private ( * Constructs with the default configuration */ @Since("1.6.0") - def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN) /** * Sets the desired number of leaf clusters (default: 4). @@ -134,6 +135,22 @@ class BisectingKMeans private ( @Since("1.6.0") def getSeed: Long = this.seed + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + DistanceMeasure.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + /** * Runs the bisecting k-means algorithm. * @param input RDD of vectors @@ -147,11 +164,13 @@ class BisectingKMeans private ( } val d = input.map(_.size).first() logInfo(s"Feature dimension: $d.") + + val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure) // Compute and cache vector norms for fast distance computation. val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) - var activeClusters = summarize(d, assignments) + var activeClusters = summarize(d, assignments, dMeasure) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -184,24 +203,25 @@ class BisectingKMeans private ( val divisibleIndices = divisibleClusters.keys.toSet logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => - val (left, right) = splitCenter(summary.center, random) + val (left, right) = splitCenter(summary.center, random, dMeasure) Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map var newClusters: Map[Long, ClusterSummary] = null var newAssignments: RDD[(Long, VectorWithNorm)] = null for (iter <- 0 until maxIterations) { - newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters, + dMeasure) .filter { case (index, _) => divisibleIndices.contains(parentIndex(index)) } - newClusters = summarize(d, newAssignments) + newClusters = summarize(d, newAssignments, dMeasure) newClusterCenters = newClusters.mapValues(_.center).map(identity) } if (preIndices != null) { preIndices.unpersist(false) } preIndices = indices - indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys .persist(StorageLevel.MEMORY_AND_DISK) assignments = indices.zip(vectors) inactiveClusters ++= activeClusters @@ -222,8 +242,8 @@ class BisectingKMeans private ( } norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters - val root = buildTree(clusters) - new BisectingKMeansModel(root) + val root = buildTree(clusters, dMeasure) + new BisectingKMeansModel(root, this.distanceMeasure) } /** @@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable { */ private def summarize( d: Int, - assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { - assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + assignments: RDD[(Long, VectorWithNorm)], + distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))( seqOp = (agg, v) => agg.add(v), combOp = (agg1, agg2) => agg1.merge(agg2) ).mapValues(_.summary) @@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable { * Cluster summary aggregator. * @param d feature dimension */ - private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure) + extends Serializable { private var n: Long = 0L private val sum: Vector = Vectors.zeros(d) private var sumSq: Double = 0.0 @@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable { n += 1L // TODO: use a numerically stable approach to estimate cost sumSq += v.norm * v.norm - BLAS.axpy(1.0, v.vector, sum) + distanceMeasure.updateClusterSum(v, sum) this } @@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable { def merge(other: ClusterSummaryAggregator): this.type = { n += other.n sumSq += other.sumSq - BLAS.axpy(1.0, other.sum, sum) + distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum) this } /** Returns the summary. */ def summary: ClusterSummary = { - val mean = sum.copy - if (n > 0L) { - BLAS.scal(1.0 / n, mean) - } - val center = new VectorWithNorm(mean) - val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) - new ClusterSummary(n, center, cost) + val center = distanceMeasure.centroid(sum.copy, n) + val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq) + ClusterSummary(n, center, cost) } } @@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable { */ private def splitCenter( center: VectorWithNorm, - random: Random): (VectorWithNorm, VectorWithNorm) = { + random: Random, + distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = { val d = center.vector.size val norm = center.norm val level = 1e-4 * norm val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) - val left = center.vector.copy - BLAS.axpy(-level, noise, left) - val right = center.vector.copy - BLAS.axpy(level, noise, right) - (new VectorWithNorm(left), new VectorWithNorm(right)) + distanceMeasure.symmetricCentroids(level, noise, center.vector) } /** @@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable { private def updateAssignments( assignments: RDD[(Long, VectorWithNorm)], divisibleIndices: Set[Long], - newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + newClusterCenters: Map[Long, VectorWithNorm], + distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = { assignments.map { case (index, v) => if (divisibleIndices.contains(index)) { val children = Seq(leftChildIndex(index), rightChildIndex(index)) - val newClusterChildren = children.filter(newClusterCenters.contains(_)) + val newClusterChildren = children.filter(newClusterCenters.contains) + val newClusterChildrenCenterToId = + newClusterChildren.map(id => newClusterCenters(id) -> id).toMap + val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray if (newClusterChildren.nonEmpty) { - val selected = newClusterChildren.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) - } - (selected, v) + val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1 + val center = newClusterChildrenCenters(selected) + val id = newClusterChildrenCenterToId(center) + (id, v) } else { (index, v) } @@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable { * @param clusters a map from cluster indices to corresponding cluster summaries * @return the root node of the clustering tree */ - private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + private def buildTree( + clusters: Map[Long, ClusterSummary], + distanceMeasure: DistanceMeasure): ClusteringTreeNode = { var leafIndex = 0 var internalIndex = -1 @@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable { internalIndex -= 1 val leftIndex = leftChildIndex(rawIndex) val rightIndex = rightChildIndex(rawIndex) - val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) - val height = math.sqrt(indexes.map { childIndex => - EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) - }.max) - val children = indexes.map(buildSubTree(_)).toArray + val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains) + val height = indexes.map { childIndex => + distanceMeasure.distance(center, clusters(childIndex).center) + }.max + val children = indexes.map(buildSubTree).toArray new ClusteringTreeNode(index, size, center, cost, height, children) } else { val index = leafIndex @@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] ( def center: Vector = centerWithNorm.vector /** Predicts the leaf cluster node index that the input point belongs to. */ - def predict(point: Vector): Int = { - val (index, _) = predict(new VectorWithNorm(point)) + def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = { + val (index, _) = predict(new VectorWithNorm(point), distanceMeasure) index } /** Returns the full prediction path from root to leaf. */ - def predictPath(point: Vector): Array[ClusteringTreeNode] = { - predictPath(new VectorWithNorm(point)).toArray + def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point), distanceMeasure).toArray } /** Returns the full prediction path from root to leaf. */ - private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + private def predictPath( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = { if (isLeaf) { this :: Nil } else { val selected = children.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + distanceMeasure.distance(child.centerWithNorm, pointWithNorm) } - selected :: selected.predictPath(pointWithNorm) + selected :: selected.predictPath(pointWithNorm, distanceMeasure) } } /** - * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + * Computes the cost of the input point. */ - def computeCost(point: Vector): Double = { - val (_, cost) = predict(new VectorWithNorm(point)) + def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = { + val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure) cost } /** * Predicts the cluster index and the cost of the input point. */ - private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, - EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) + private def predict( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): (Int, Double) = { + predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure) } /** @@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * @return (predicted leaf cluster index, cost) */ @tailrec - private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + private def predict( + pointWithNorm: VectorWithNorm, + cost: Double, + distanceMeasure: DistanceMeasure): (Int, Double) = { if (isLeaf) { (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) - selectedChild.predict(pointWithNorm, minCost) + selectedChild.predict(pointWithNorm, minCost, distanceMeasure) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 633bda6aac804..9d115afcea75d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession} */ @Since("1.6.0") class BisectingKMeansModel private[clustering] ( - private[clustering] val root: ClusteringTreeNode + private[clustering] val root: ClusteringTreeNode, + @Since("2.4.0") val distanceMeasure: String ) extends Serializable with Saveable with Logging { + @Since("1.6.0") + def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) + + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + /** * Leaf cluster centers. */ @@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(point: Vector): Int = { - root.predict(point) + root.predict(point, distanceMeasureInstance) } /** @@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(points: RDD[Vector]): RDD[Int] = { - points.map { p => root.predict(p) } + points.map { p => root.predict(p, distanceMeasureInstance) } } /** @@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(point: Vector): Double = { - root.computeCost(point) + root.computeCost(point, distanceMeasureInstance) } /** @@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: RDD[Vector]): Double = { - data.map(root.computeCost).sum() + data.map(root.computeCost(_, distanceMeasureInstance)).sum() } /** @@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { @Since("2.0.0") override def load(sc: SparkContext, path: String): BisectingKMeansModel = { - val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) - implicit val formats = DefaultFormats - val rootId = (metadata \ "rootId").extract[Int] - val classNameV1_0 = SaveLoadV1_0.thisClassName + val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path) (loadedClassName, formatVersion) match { - case (classNameV1_0, "1.0") => - val model = SaveLoadV1_0.load(sc, path, rootId) + case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) + model + case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $formatVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + + s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") } } @@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) } + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes) ++ Array(node) + } + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes(rootId) + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + private[clustering] object SaveLoadV1_0 { - private val thisFormatVersion = "1.0" + private[clustering] val thisFormatVersion = "1.0" private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" @@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) } - private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { - if (node.children.isEmpty) { - Array(node) - } else { - node.children.flatMap(getNodes(_)) ++ Array(node) - } - } - - def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode) + new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) } + } + + private[clustering] object SaveLoadV2_0 { + private[clustering] val thisFormatVersion = "2.0" - private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { - val root = nodes.get(rootId).get - if (root.children.isEmpty) { - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, new Array[ClusteringTreeNode](0)) - } else { - val children = root.children.map(c => buildTree(c, nodes)) - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, children.toArray) - } + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode, distanceMeasure) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala new file mode 100644 index 0000000000000..683360efabc76 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} +import org.apache.spark.mllib.util.MLUtils + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return the K-means cost of a given point against the given cluster centers. + */ + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + + /** + * @return the total cost of the cluster from its aggregated properties + */ + def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val left = centroid.copy + axpy(-level, noise, left) + val right = centroid.copy + axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = distance(point, centroid) +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary + // distance computation. + var lowerBoundOfSqDist = center.norm - point.norm + lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist + if (lowerBoundOfSqDist < bestDistance) { + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + override def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = { + EuclideanDistanceMeasure.fastSquaredDistance(point, centroid) + } +} + + +private[spark] object EuclideanDistanceMeasure { + /** + * @return the squared Euclidean distance between two vectors computed by + * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. + */ + private[clustering] def fastSquaredDistance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double = { + MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) + } +} + +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.") + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + val costVector = pointsSum.vector.copy + math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + override def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val (left, right) = super.symmetricCentroids(level, noise, centroid) + val leftVector = left.vector + val rightVector = right.vector + scal(1.0 / left.norm, leftVector) + scal(1.0 / right.norm, rightVector) + (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 3c4ba0bc60c7f..b5b1be3490497 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -204,7 +203,7 @@ class KMeans private ( */ @Since("2.4.0") def setDistanceMeasure(distanceMeasure: String): this.type = { - KMeans.validateDistanceMeasure(distanceMeasure) + DistanceMeasure.validateDistanceMeasure(distanceMeasure) this.distanceMeasure = distanceMeasure this } @@ -582,14 +581,6 @@ object KMeans { case _ => false } } - - private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { - distanceMeasure match { - case DistanceMeasure.EUCLIDEAN => true - case DistanceMeasure.COSINE => true - case _ => false - } - } } /** @@ -605,186 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) /** Converts the vector to a dense vector. */ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } - - -private[spark] abstract class DistanceMeasure extends Serializable { - - /** - * @return the index of the closest center to the given point, as well as the cost. - */ - def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - val currentDistance = distance(center, point) - if (currentDistance < bestDistance) { - bestDistance = currentDistance - bestIndex = i - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return the K-means cost of a given point against the given cluster centers. - */ - def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = { - findClosest(centers, point)._2 - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - distance(oldCenter, newCenter) <= epsilon - } - - /** - * @return the cosine distance between two points. - */ - def distance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - new VectorWithNorm(sum) - } -} - -@Since("2.4.0") -object DistanceMeasure { - - @Since("2.4.0") - val EUCLIDEAN = "euclidean" - @Since("2.4.0") - val COSINE = "cosine" - - private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = - distanceMeasure match { - case EUCLIDEAN => new EuclideanDistanceMeasure - case COSINE => new CosineDistanceMeasure - case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + - s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") - } -} - -private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { - /** - * @return the index of the closest center to the given point, as well as the squared distance. - */ - override def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary - // distance computation. - var lowerBoundOfSqDist = center.norm - point.norm - lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist - if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - override def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon - } - - /** - * @param v1: first vector - * @param v2: second vector - * @return the Euclidean distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) - } -} - - -private[spark] object EuclideanDistanceMeasure { - /** - * @return the squared Euclidean distance between two vectors computed by - * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. - */ - private[clustering] def fastSquaredDistance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double = { - MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) - } -} - -private[spark] class CosineDistanceMeasure extends DistanceMeasure { - /** - * @param v1: first vector - * @param v2: second vector - * @return the cosine distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") - 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm - } - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0 / point.norm, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - override def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - val norm = Vectors.norm(sum, 2) - scal(1.0 / norm, sum) - new VectorWithNorm(sum, 1) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fa7471fa2d658..02880f96ae6d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.clustering -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -140,6 +142,46 @@ class BisectingKMeansSuite testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, BisectingKMeansSuite.allParamSettings, checkModelData) } + + test("BisectingKMeans with cosine distance is not supported for 0-length vectors") { + val model = new BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) + } + + test("BisectingKMeans with cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + val model = new BisectingKMeans() + .setK(3) + .setDistanceMeasure(DistanceMeasure.COSINE) + .setSeed(1) + .fit(df) + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } } object BisectingKMeansSuite { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 381f7b5be1ddf..1b6d1dec69d49 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,12 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), From 23370554d0f88b82154d4232744b874cc58c7848 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 15:20:09 +0100 Subject: [PATCH 157/593] [SPARK-23656][TEST] Perform assertions in XXH64Suite.testKnownByteArrayInputs() on big endian platform, too ## What changes were proposed in this pull request? This PR enables assertions in `XXH64Suite.testKnownByteArrayInputs()` on big endian platform, too. The current implementation performs them only on little endian platform. This PR increase test coverage of big endian platform. ## How was this patch tested? Updated `XXH64Suite` Tested on big endian platform using JIT compiler or interpreter `-Xint`. Author: Kazuaki Ishizaki Closes #20804 from kiszk/SPARK-23656. --- .../sql/catalyst/expressions/XXH64Suite.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 711887f02832a..1baee91b3439c 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -74,9 +74,6 @@ public void testKnownByteArrayInputs() { Assert.assertEquals(0x739840CB819FA723L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); - // These tests currently fail in a big endian environment because the test data and expected - // answers are generated with little endian the assumptions. We could revisit this when Platform - // becomes endian aware. if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); @@ -94,6 +91,23 @@ public void testKnownByteArrayInputs() { hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); Assert.assertEquals(0xCAA65939306F1E21L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } else { + Assert.assertEquals(0x7F875412350ADDDCL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x564D279F524D8516L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0x7D9F07E27E0EB006L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x893CEF564CB7858L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xC6198C4C9CC49E17L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x4E21BEF7164D4BBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0xBCF5FAEDEE1F2B5AL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0x6F680C877A358FE5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); } } From 9ddd1e2ceac8155b30beebb6bbfdcd32296fab2d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 13 Mar 2018 23:31:08 +0900 Subject: [PATCH 158/593] [MINOR][SQL][TEST] Create table using `dataSourceName` in `HadoopFsRelationTest` ## What changes were proposed in this pull request? This PR fixes a minor issue in `HadoopFsRelationTest`, that you should create table using `dataSourceName` instead of `parquet`. The issue won't affect the correctness, but it will generate wrong error message in case the test fails. ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #20780 from jiangxb1987/dataSourceName. --- .../apache/spark/sql/sources/HadoopFsRelationTest.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 80aff446bc24b..53397991e59dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -335,16 +335,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") - intercept[AnalysisException] { + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") + val msg = intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") - } + }.getMessage + assert(msg.contains("Table `t` already exists")) } } test("saveAsTable()/load() - non-partitioned table - Ignore") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } From 918fb9beee6a2fd499b8f18dfe0d460f078f5290 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Tue, 13 Mar 2018 11:31:32 -0700 Subject: [PATCH 159/593] [SPARK-23547][SQL] Cleanup the .pipeout file when the Hive Session closed ## What changes were proposed in this pull request? ![2018-03-07_121010](https://user-images.githubusercontent.com/24823338/37073232-922e10d2-2200-11e8-8172-6e03aa984b39.png) when the hive session closed, we should also cleanup the .pipeout file. ## How was this patch tested? Added test cases. Author: zuotingbing Closes #20702 from zuotingbing/SPARK-23547. --- .../service/cli/session/HiveSessionImpl.java | 18 +++++++++++ .../HiveThriftServer2Suites.scala | 32 ++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index fc818bc69c761..f59cdcd3188e6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -641,6 +641,8 @@ public void close() throws HiveSQLException { opHandleSet.clear(); // Cleanup session log directory. cleanupSessionLogDir(); + // Cleanup pipeout file. + cleanupPipeoutFile(); HiveHistory hiveHist = sessionState.getHiveHistory(); if (null != hiveHist) { hiveHist.closeStream(); @@ -665,6 +667,22 @@ public void close() throws HiveSQLException { } } + private void cleanupPipeoutFile() { + String lScratchDir = hiveConf.getVar(ConfVars.LOCALSCRATCHDIR); + String sessionID = hiveConf.getVar(ConfVars.HIVESESSIONID); + + File[] fileAry = new File(lScratchDir).listFiles( + (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); + + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } + } + } + private void cleanupSessionLogDir() { if (isOperationLogEnabled) { try { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b32c547cefefe..192f33a45e273 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import java.io.File +import java.io.{File, FilenameFilter} import java.net.URL import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -613,6 +614,28 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { bufferSrc.close() } } + + test("SPARK-23547 Cleanup the .pipeout file when the Hive Session closed") { + def pipeoutFileList(sessionID: UUID): Array[File] = { + lScratchDir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + name.startsWith(sessionID.toString) && name.endsWith(".pipeout") + } + }) + } + + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + val sessionID = sessionHandle.getSessionId + + assert(pipeoutFileList(sessionID).length == 1) + + client.closeSession(sessionHandle) + + assert(pipeoutFileList(sessionID).length == 0) + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -807,6 +830,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ + protected var lScratchDir: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -844,6 +868,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath + | --hiveconf ${ConfVars.LOCALSCRATCHDIR}=$lScratchDir | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -873,6 +898,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() operationLogPath = Utils.createTempDir() operationLogPath.delete() + lScratchDir = Utils.createTempDir() + lScratchDir.delete() logPath = null logTailingProcess = null @@ -956,6 +983,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl operationLogPath.delete() operationLogPath = null + lScratchDir.delete() + lScratchDir = null + Option(logPath).foreach(_.delete()) logPath = null From 1098933b0ac5cdb18101d3aebefa773c2ce05a50 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 23:04:16 +0100 Subject: [PATCH 160/593] [SPARK-23598][SQL] Make methods in BufferedRowIterator public to avoid runtime error for a large query ## What changes were proposed in this pull request? This PR fixes runtime error regarding a large query when a generated code has split classes. The issue is `append()`, `stopEarly()`, and other methods are not accessible from split classes that are not subclasses of `BufferedRowIterator`. This PR fixes this issue by making them `public`. Before applying the PR, we see the following exception by running the attached program with `CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD=-1`. ``` test("SPARK-23598") { // When set -1 to CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD, an exception is thrown val df_pet_age = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") df_pet_age.groupBy("name").avg("age").show() } ``` Exception: ``` 19:40:52.591 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 19:41:32.319 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.IllegalAccessError: tried to access method org.apache.spark.sql.execution.BufferedRowIterator.shouldStop()Z from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1 at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1.agg_doAggregateWithKeys$(generated.java:203) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(generated.java:160) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:616) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408) at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ... ``` Generated code (line 195 calles `stopEarly()`). ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean agg_initAgg; /* 010 */ private boolean agg_bufIsNull; /* 011 */ private double agg_bufValue; /* 012 */ private boolean agg_bufIsNull1; /* 013 */ private long agg_bufValue1; /* 014 */ private agg_FastHashMap agg_fastHashMap; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_fastHashMapIter; /* 016 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 017 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 018 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 019 */ private scala.collection.Iterator inputadapter_input; /* 020 */ private boolean agg_agg_isNull11; /* 021 */ private boolean agg_agg_isNull25; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[] agg_mutableStateArray1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[2]; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 024 */ private UnsafeRow[] agg_mutableStateArray = new UnsafeRow[2]; /* 025 */ /* 026 */ public GeneratedIteratorForCodegenStage1(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ /* 034 */ agg_fastHashMap = new agg_FastHashMap(((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskMemoryManager(), ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getEmptyAggregationBuffer()); /* 035 */ agg_hashMap = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).createHashMap(); /* 036 */ inputadapter_input = inputs[0]; /* 037 */ agg_mutableStateArray[0] = new UnsafeRow(1); /* 038 */ agg_mutableStateArray1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[0], 32); /* 039 */ agg_mutableStateArray2[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[0], 1); /* 040 */ agg_mutableStateArray[1] = new UnsafeRow(3); /* 041 */ agg_mutableStateArray1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[1], 32); /* 042 */ agg_mutableStateArray2[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[1], 3); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ public class agg_FastHashMap { /* 047 */ private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; /* 048 */ private int[] buckets; /* 049 */ private int capacity = 1 << 16; /* 050 */ private double loadFactor = 0.5; /* 051 */ private int numBuckets = (int) (capacity / loadFactor); /* 052 */ private int maxSteps = 2; /* 053 */ private int numRows = 0; /* 054 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1] /* keyName */), org.apache.spark.sql.types.DataTypes.StringType); /* 055 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2] /* keyName */), org.apache.spark.sql.types.DataTypes.DoubleType) /* 056 */ .add(((java.lang.String) references[3] /* keyName */), org.apache.spark.sql.types.DataTypes.LongType); /* 057 */ private Object emptyVBase; /* 058 */ private long emptyVOff; /* 059 */ private int emptyVLen; /* 060 */ private boolean isBatchFull = false; /* 061 */ /* 062 */ public agg_FastHashMap( /* 063 */ org.apache.spark.memory.TaskMemoryManager taskMemoryManager, /* 064 */ InternalRow emptyAggregationBuffer) { /* 065 */ batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch /* 066 */ .allocate(keySchema, valueSchema, taskMemoryManager, capacity); /* 067 */ /* 068 */ final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); /* 069 */ final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); /* 070 */ /* 071 */ emptyVBase = emptyBuffer; /* 072 */ emptyVOff = Platform.BYTE_ARRAY_OFFSET; /* 073 */ emptyVLen = emptyBuffer.length; /* 074 */ /* 075 */ buckets = new int[numBuckets]; /* 076 */ java.util.Arrays.fill(buckets, -1); /* 077 */ } /* 078 */ /* 079 */ public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) { /* 080 */ long h = hash(agg_key); /* 081 */ int step = 0; /* 082 */ int idx = (int) h & (numBuckets - 1); /* 083 */ while (step < maxSteps) { /* 084 */ // Return bucket index if it's either an empty slot or already contains the key /* 085 */ if (buckets[idx] == -1) { /* 086 */ if (numRows < capacity && !isBatchFull) { /* 087 */ // creating the unsafe for new entry /* 088 */ UnsafeRow agg_result = new UnsafeRow(1); /* 089 */ org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder /* 090 */ = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, /* 091 */ 32); /* 092 */ org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter /* 093 */ = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( /* 094 */ agg_holder, /* 095 */ 1); /* 096 */ agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed /* 097 */ agg_rowWriter.zeroOutNullBytes(); /* 098 */ agg_rowWriter.write(0, agg_key); /* 099 */ agg_result.setTotalSize(agg_holder.totalSize()); /* 100 */ Object kbase = agg_result.getBaseObject(); /* 101 */ long koff = agg_result.getBaseOffset(); /* 102 */ int klen = agg_result.getSizeInBytes(); /* 103 */ /* 104 */ UnsafeRow vRow /* 105 */ = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen); /* 106 */ if (vRow == null) { /* 107 */ isBatchFull = true; /* 108 */ } else { /* 109 */ buckets[idx] = numRows++; /* 110 */ } /* 111 */ return vRow; /* 112 */ } else { /* 113 */ // No more space /* 114 */ return null; /* 115 */ } /* 116 */ } else if (equals(idx, agg_key)) { /* 117 */ return batch.getValueRow(buckets[idx]); /* 118 */ } /* 119 */ idx = (idx + 1) & (numBuckets - 1); /* 120 */ step++; /* 121 */ } /* 122 */ // Didn't find it /* 123 */ return null; /* 124 */ } /* 125 */ /* 126 */ private boolean equals(int idx, UTF8String agg_key) { /* 127 */ UnsafeRow row = batch.getKeyRow(buckets[idx]); /* 128 */ return (row.getUTF8String(0).equals(agg_key)); /* 129 */ } /* 130 */ /* 131 */ private long hash(UTF8String agg_key) { /* 132 */ long agg_hash = 0; /* 133 */ /* 134 */ int agg_result = 0; /* 135 */ byte[] agg_bytes = agg_key.getBytes(); /* 136 */ for (int i = 0; i < agg_bytes.length; i++) { /* 137 */ int agg_hash1 = agg_bytes[i]; /* 138 */ agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2); /* 139 */ } /* 140 */ /* 141 */ agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2); /* 142 */ /* 143 */ return agg_hash; /* 144 */ } /* 145 */ /* 146 */ public org.apache.spark.unsafe.KVIterator rowIterator() { /* 147 */ return batch.rowIterator(); /* 148 */ } /* 149 */ /* 150 */ public void close() { /* 151 */ batch.close(); /* 152 */ } /* 153 */ /* 154 */ } /* 155 */ /* 156 */ protected void processNext() throws java.io.IOException { /* 157 */ if (!agg_initAgg) { /* 158 */ agg_initAgg = true; /* 159 */ long wholestagecodegen_beforeAgg = System.nanoTime(); /* 160 */ agg_nestedClassInstance1.agg_doAggregateWithKeys(); /* 161 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[8] /* aggTime */).add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000); /* 162 */ } /* 163 */ /* 164 */ // output the result /* 165 */ /* 166 */ while (agg_fastHashMapIter.next()) { /* 167 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey(); /* 168 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue(); /* 169 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 170 */ /* 171 */ if (shouldStop()) return; /* 172 */ } /* 173 */ agg_fastHashMap.close(); /* 174 */ /* 175 */ while (agg_mapIter.next()) { /* 176 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 177 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 178 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 179 */ /* 180 */ if (shouldStop()) return; /* 181 */ } /* 182 */ /* 183 */ agg_mapIter.close(); /* 184 */ if (agg_sorter == null) { /* 185 */ agg_hashMap.free(); /* 186 */ } /* 187 */ } /* 188 */ /* 189 */ private wholestagecodegen_NestedClass wholestagecodegen_nestedClassInstance = new wholestagecodegen_NestedClass(); /* 190 */ private agg_NestedClass1 agg_nestedClassInstance1 = new agg_NestedClass1(); /* 191 */ private agg_NestedClass agg_nestedClassInstance = new agg_NestedClass(); /* 192 */ /* 193 */ private class agg_NestedClass1 { /* 194 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 195 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 196 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 197 */ int inputadapter_value = inputadapter_row.getInt(0); /* 198 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 199 */ UTF8String inputadapter_value1 = inputadapter_isNull1 ? /* 200 */ null : (inputadapter_row.getUTF8String(1)); /* 201 */ /* 202 */ agg_nestedClassInstance.agg_doConsume(inputadapter_row, inputadapter_value, inputadapter_value1, inputadapter_isNull1); /* 203 */ if (shouldStop()) return; /* 204 */ } /* 205 */ /* 206 */ agg_fastHashMapIter = agg_fastHashMap.rowIterator(); /* 207 */ agg_mapIter = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap, agg_sorter, ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[6] /* avgHashProbe */)); /* 208 */ /* 209 */ } /* 210 */ /* 211 */ } /* 212 */ /* 213 */ private class wholestagecodegen_NestedClass { /* 214 */ private void agg_doAggregateWithKeysOutput(UnsafeRow agg_keyTerm, UnsafeRow agg_bufferTerm) /* 215 */ throws java.io.IOException { /* 216 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* numOutputRows */).add(1); /* 217 */ /* 218 */ boolean agg_isNull35 = agg_keyTerm.isNullAt(0); /* 219 */ UTF8String agg_value37 = agg_isNull35 ? /* 220 */ null : (agg_keyTerm.getUTF8String(0)); /* 221 */ boolean agg_isNull36 = agg_bufferTerm.isNullAt(0); /* 222 */ double agg_value38 = agg_isNull36 ? /* 223 */ -1.0 : (agg_bufferTerm.getDouble(0)); /* 224 */ boolean agg_isNull37 = agg_bufferTerm.isNullAt(1); /* 225 */ long agg_value39 = agg_isNull37 ? /* 226 */ -1L : (agg_bufferTerm.getLong(1)); /* 227 */ /* 228 */ agg_mutableStateArray1[1].reset(); /* 229 */ /* 230 */ agg_mutableStateArray2[1].zeroOutNullBytes(); /* 231 */ /* 232 */ if (agg_isNull35) { /* 233 */ agg_mutableStateArray2[1].setNullAt(0); /* 234 */ } else { /* 235 */ agg_mutableStateArray2[1].write(0, agg_value37); /* 236 */ } /* 237 */ /* 238 */ if (agg_isNull36) { /* 239 */ agg_mutableStateArray2[1].setNullAt(1); /* 240 */ } else { /* 241 */ agg_mutableStateArray2[1].write(1, agg_value38); /* 242 */ } /* 243 */ /* 244 */ if (agg_isNull37) { /* 245 */ agg_mutableStateArray2[1].setNullAt(2); /* 246 */ } else { /* 247 */ agg_mutableStateArray2[1].write(2, agg_value39); /* 248 */ } /* 249 */ agg_mutableStateArray[1].setTotalSize(agg_mutableStateArray1[1].totalSize()); /* 250 */ append(agg_mutableStateArray[1]); /* 251 */ /* 252 */ } /* 253 */ /* 254 */ } /* 255 */ /* 256 */ private class agg_NestedClass { /* 257 */ private void agg_doConsume(InternalRow inputadapter_row, int agg_expr_0, UTF8String agg_expr_1, boolean agg_exprIsNull_1) throws java.io.IOException { /* 258 */ UnsafeRow agg_unsafeRowAggBuffer = null; /* 259 */ UnsafeRow agg_fastAggBuffer = null; /* 260 */ /* 261 */ if (true) { /* 262 */ if (!agg_exprIsNull_1) { /* 263 */ agg_fastAggBuffer = agg_fastHashMap.findOrInsert( /* 264 */ agg_expr_1); /* 265 */ } /* 266 */ } /* 267 */ // Cannot find the key in fast hash map, try regular hash map. /* 268 */ if (agg_fastAggBuffer == null) { /* 269 */ // generate grouping key /* 270 */ agg_mutableStateArray1[0].reset(); /* 271 */ /* 272 */ agg_mutableStateArray2[0].zeroOutNullBytes(); /* 273 */ /* 274 */ if (agg_exprIsNull_1) { /* 275 */ agg_mutableStateArray2[0].setNullAt(0); /* 276 */ } else { /* 277 */ agg_mutableStateArray2[0].write(0, agg_expr_1); /* 278 */ } /* 279 */ agg_mutableStateArray[0].setTotalSize(agg_mutableStateArray1[0].totalSize()); /* 280 */ int agg_value7 = 42; /* 281 */ /* 282 */ if (!agg_exprIsNull_1) { /* 283 */ agg_value7 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(agg_expr_1.getBaseObject(), agg_expr_1.getBaseOffset(), agg_expr_1.numBytes(), agg_value7); /* 284 */ } /* 285 */ if (true) { /* 286 */ // try to get the buffer from hash map /* 287 */ agg_unsafeRowAggBuffer = /* 288 */ agg_hashMap.getAggregationBufferFromUnsafeRow(agg_mutableStateArray[0], agg_value7); /* 289 */ } /* 290 */ // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based /* 291 */ // aggregation after processing all input rows. /* 292 */ if (agg_unsafeRowAggBuffer == null) { /* 293 */ if (agg_sorter == null) { /* 294 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 295 */ } else { /* 296 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 297 */ } /* 298 */ /* 299 */ // the hash map had be spilled, it should have enough memory now, /* 300 */ // try to allocate buffer again. /* 301 */ agg_unsafeRowAggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow( /* 302 */ agg_mutableStateArray[0], agg_value7); /* 303 */ if (agg_unsafeRowAggBuffer == null) { /* 304 */ // failed to allocate the first page /* 305 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 306 */ } /* 307 */ } /* 308 */ /* 309 */ } /* 310 */ /* 311 */ if (agg_fastAggBuffer != null) { /* 312 */ // common sub-expressions /* 313 */ boolean agg_isNull21 = false; /* 314 */ long agg_value23 = -1L; /* 315 */ if (!false) { /* 316 */ agg_value23 = (long) agg_expr_0; /* 317 */ } /* 318 */ // evaluate aggregate function /* 319 */ boolean agg_isNull23 = true; /* 320 */ double agg_value25 = -1.0; /* 321 */ /* 322 */ boolean agg_isNull24 = agg_fastAggBuffer.isNullAt(0); /* 323 */ double agg_value26 = agg_isNull24 ? /* 324 */ -1.0 : (agg_fastAggBuffer.getDouble(0)); /* 325 */ if (!agg_isNull24) { /* 326 */ agg_agg_isNull25 = true; /* 327 */ double agg_value27 = -1.0; /* 328 */ do { /* 329 */ boolean agg_isNull26 = agg_isNull21; /* 330 */ double agg_value28 = -1.0; /* 331 */ if (!agg_isNull21) { /* 332 */ agg_value28 = (double) agg_value23; /* 333 */ } /* 334 */ if (!agg_isNull26) { /* 335 */ agg_agg_isNull25 = false; /* 336 */ agg_value27 = agg_value28; /* 337 */ continue; /* 338 */ } /* 339 */ /* 340 */ boolean agg_isNull27 = false; /* 341 */ double agg_value29 = -1.0; /* 342 */ if (!false) { /* 343 */ agg_value29 = (double) 0; /* 344 */ } /* 345 */ if (!agg_isNull27) { /* 346 */ agg_agg_isNull25 = false; /* 347 */ agg_value27 = agg_value29; /* 348 */ continue; /* 349 */ } /* 350 */ /* 351 */ } while (false); /* 352 */ /* 353 */ agg_isNull23 = false; // resultCode could change nullability. /* 354 */ agg_value25 = agg_value26 + agg_value27; /* 355 */ /* 356 */ } /* 357 */ boolean agg_isNull29 = false; /* 358 */ long agg_value31 = -1L; /* 359 */ if (!false && agg_isNull21) { /* 360 */ boolean agg_isNull31 = agg_fastAggBuffer.isNullAt(1); /* 361 */ long agg_value33 = agg_isNull31 ? /* 362 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 363 */ agg_isNull29 = agg_isNull31; /* 364 */ agg_value31 = agg_value33; /* 365 */ } else { /* 366 */ boolean agg_isNull32 = true; /* 367 */ long agg_value34 = -1L; /* 368 */ /* 369 */ boolean agg_isNull33 = agg_fastAggBuffer.isNullAt(1); /* 370 */ long agg_value35 = agg_isNull33 ? /* 371 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 372 */ if (!agg_isNull33) { /* 373 */ agg_isNull32 = false; // resultCode could change nullability. /* 374 */ agg_value34 = agg_value35 + 1L; /* 375 */ /* 376 */ } /* 377 */ agg_isNull29 = agg_isNull32; /* 378 */ agg_value31 = agg_value34; /* 379 */ } /* 380 */ // update fast row /* 381 */ if (!agg_isNull23) { /* 382 */ agg_fastAggBuffer.setDouble(0, agg_value25); /* 383 */ } else { /* 384 */ agg_fastAggBuffer.setNullAt(0); /* 385 */ } /* 386 */ /* 387 */ if (!agg_isNull29) { /* 388 */ agg_fastAggBuffer.setLong(1, agg_value31); /* 389 */ } else { /* 390 */ agg_fastAggBuffer.setNullAt(1); /* 391 */ } /* 392 */ } else { /* 393 */ // common sub-expressions /* 394 */ boolean agg_isNull7 = false; /* 395 */ long agg_value9 = -1L; /* 396 */ if (!false) { /* 397 */ agg_value9 = (long) agg_expr_0; /* 398 */ } /* 399 */ // evaluate aggregate function /* 400 */ boolean agg_isNull9 = true; /* 401 */ double agg_value11 = -1.0; /* 402 */ /* 403 */ boolean agg_isNull10 = agg_unsafeRowAggBuffer.isNullAt(0); /* 404 */ double agg_value12 = agg_isNull10 ? /* 405 */ -1.0 : (agg_unsafeRowAggBuffer.getDouble(0)); /* 406 */ if (!agg_isNull10) { /* 407 */ agg_agg_isNull11 = true; /* 408 */ double agg_value13 = -1.0; /* 409 */ do { /* 410 */ boolean agg_isNull12 = agg_isNull7; /* 411 */ double agg_value14 = -1.0; /* 412 */ if (!agg_isNull7) { /* 413 */ agg_value14 = (double) agg_value9; /* 414 */ } /* 415 */ if (!agg_isNull12) { /* 416 */ agg_agg_isNull11 = false; /* 417 */ agg_value13 = agg_value14; /* 418 */ continue; /* 419 */ } /* 420 */ /* 421 */ boolean agg_isNull13 = false; /* 422 */ double agg_value15 = -1.0; /* 423 */ if (!false) { /* 424 */ agg_value15 = (double) 0; /* 425 */ } /* 426 */ if (!agg_isNull13) { /* 427 */ agg_agg_isNull11 = false; /* 428 */ agg_value13 = agg_value15; /* 429 */ continue; /* 430 */ } /* 431 */ /* 432 */ } while (false); /* 433 */ /* 434 */ agg_isNull9 = false; // resultCode could change nullability. /* 435 */ agg_value11 = agg_value12 + agg_value13; /* 436 */ /* 437 */ } /* 438 */ boolean agg_isNull15 = false; /* 439 */ long agg_value17 = -1L; /* 440 */ if (!false && agg_isNull7) { /* 441 */ boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1); /* 442 */ long agg_value19 = agg_isNull17 ? /* 443 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 444 */ agg_isNull15 = agg_isNull17; /* 445 */ agg_value17 = agg_value19; /* 446 */ } else { /* 447 */ boolean agg_isNull18 = true; /* 448 */ long agg_value20 = -1L; /* 449 */ /* 450 */ boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1); /* 451 */ long agg_value21 = agg_isNull19 ? /* 452 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 453 */ if (!agg_isNull19) { /* 454 */ agg_isNull18 = false; // resultCode could change nullability. /* 455 */ agg_value20 = agg_value21 + 1L; /* 456 */ /* 457 */ } /* 458 */ agg_isNull15 = agg_isNull18; /* 459 */ agg_value17 = agg_value20; /* 460 */ } /* 461 */ // update unsafe row buffer /* 462 */ if (!agg_isNull9) { /* 463 */ agg_unsafeRowAggBuffer.setDouble(0, agg_value11); /* 464 */ } else { /* 465 */ agg_unsafeRowAggBuffer.setNullAt(0); /* 466 */ } /* 467 */ /* 468 */ if (!agg_isNull15) { /* 469 */ agg_unsafeRowAggBuffer.setLong(1, agg_value17); /* 470 */ } else { /* 471 */ agg_unsafeRowAggBuffer.setNullAt(1); /* 472 */ } /* 473 */ /* 474 */ } /* 475 */ /* 476 */ } /* 477 */ /* 478 */ } /* 479 */ /* 480 */ } ``` ## How was this patch tested? Added UT into `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #20779 from kiszk/SPARK-23598. --- .../spark/sql/execution/BufferedRowIterator.java | 12 ++++++++---- .../spark/sql/execution/WholeStageCodegenSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 730a4ae8d5605..74c9c05992719 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -62,10 +62,14 @@ public long durationMs() { */ public abstract void init(int index, Iterator[] iters); + /* + * Attributes of the following four methods are public. Thus, they can be also accessed from + * methods in inner classes. See SPARK-23598 + */ /** * Append a row to currentRows. */ - protected void append(InternalRow row) { + public void append(InternalRow row) { currentRows.add(row); } @@ -75,7 +79,7 @@ protected void append(InternalRow row) { * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. * This interface is mainly used to limit the number of input rows. */ - protected boolean stopEarly() { + public boolean stopEarly() { return false; } @@ -84,14 +88,14 @@ protected boolean stopEarly() { * * If it returns true, the caller should exit the loop (return from processNext()). */ - protected boolean shouldStop() { + public boolean shouldStop() { return !currentRows.isEmpty(); } /** * Increase the peak execution memory for current task. */ - protected void incPeakExecutionMemory(long size) { + public void incPeakExecutionMemory(long size) { TaskContext.get().taskMetrics().incPeakExecutionMemory(size); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0fb9dd2017a09..4b40e4ef7571c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan @@ -307,4 +309,14 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { // a different query can result in codegen cache miss, that's by design } } + + test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") + for (i <- 0 until 70) { + df = df.groupBy("name").agg(avg("age").alias("age")) + } + assert(df.limit(1).collect() === Array(Row("bat", 8.0))) + } + } } From 279b3db8970809104c30941254e57e3d62da5041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Mar 2018 18:36:31 -0700 Subject: [PATCH 161/593] [SPARK-22915][MLLIB] Streaming tests for spark.ml.feature, from N to Z MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: - NGramSuite - NormalizerSuite - OneHotEncoderEstimatorSuite - OneHotEncoderSuite - PCASuite - PolynomialExpansionSuite - QuantileDiscretizerSuite - RFormulaSuite - SQLTransformerSuite - StandardScalerSuite - StopWordsRemoverSuite - StringIndexerSuite - TokenizerSuite - RegexTokenizerSuite - VectorAssemblerSuite - VectorIndexerSuite - VectorSizeHintSuite - VectorSlicerSuite - Word2VecSuite # How was this patch tested? They are unit test. Author: “attilapiros” Closes #20686 from attilapiros/SPARK-22915. --- .../apache/spark/ml/feature/NGramSuite.scala | 23 +- .../spark/ml/feature/NormalizerSuite.scala | 57 ++--- .../feature/OneHotEncoderEstimatorSuite.scala | 193 ++++++++--------- .../spark/ml/feature/OneHotEncoderSuite.scala | 124 ++++++----- .../apache/spark/ml/feature/PCASuite.scala | 14 +- .../ml/feature/PolynomialExpansionSuite.scala | 62 +++--- .../ml/feature/QuantileDiscretizerSuite.scala | 198 +++++++++-------- .../spark/ml/feature/RFormulaSuite.scala | 158 +++++++------- .../ml/feature/SQLTransformerSuite.scala | 35 +-- .../ml/feature/StandardScalerSuite.scala | 33 +-- .../ml/feature/StopWordsRemoverSuite.scala | 37 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 204 +++++++++--------- .../spark/ml/feature/TokenizerSuite.scala | 30 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 183 +++++++++------- .../ml/feature/VectorSizeHintSuite.scala | 88 +++++--- .../spark/ml/feature/VectorSlicerSuite.scala | 27 +-- .../spark/ml/feature/Word2VecSuite.scala | 28 +-- .../org/apache/spark/ml/util/MLTest.scala | 33 ++- 18 files changed, 809 insertions(+), 718 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index d4975c0b4e20e..e5956ee9942aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} + @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NGramSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.NGramSuite._ import testImplicits._ test("default behavior yields bigram features") { @@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setN(3) testDefaultReadWrite(t) } -} - -object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("nGrams", "wantedNGrams") - .collect() - .foreach { case Row(actualNGrams, wantedNGrams) => + def testNGram(t: NGram, dataFrame: DataFrame): Unit = { + testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { + case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => assert(actualNGrams === wantedNGrams) - } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index c75027fb4553d..eff57f1223af4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,21 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NormalizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @transient var data: Array[Vector] = _ - @transient var dataFrame: DataFrame = _ - @transient var normalizer: Normalizer = _ @transient var l1Normalized: Array[Vector] = _ @transient var l2Normalized: Array[Vector] = _ @@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.dense(0.897906166, 0.113419726, 0.42532397), Vectors.sparse(3, Seq()) ) - - dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() - normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normalized_features") - } - - def collectResult(result: DataFrame): Array[Vector] = { - result.select("normalized_features").collect().map { - case Row(features: Vector) => features - } } - def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { case (v1: DenseVector, v2: DenseVector) => true case (v1: SparseVector, v2: SparseVector) => true case _ => false }, "The vector type should be preserved after normalization.") } - def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { (vector1, vector2) => - vector1 ~== vector2 absTol 1E-5 - }, "The vector value is not correct after normalization.") + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.") } test("Normalization with default parameter") { - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized") + val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected") - assertValues(result, l2Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("Normalization with setter") { - normalizer.setP(1) + val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected") + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1) - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) - - assertValues(result, l1Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("read/write") { @@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } } - -private object NormalizerSuite { - case class FeatureData(features: Vector) -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 1d3f845586426..d549e13262273 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ -class OneHotEncoderEstimatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("size")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } test("input column without ML attribute") { @@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("index")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } } test("read/write") { @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite val df = spark.createDataFrame(sc.parallelize(data), schema) - val dfWithTypes = df - .withColumn("shortInput", df("input").cast(ShortType)) - .withColumn("longInput", df("input").cast(LongType)) - .withColumn("intInput", df("input").cast(IntegerType)) - .withColumn("floatInput", df("input").cast(FloatType)) - .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) - - val cols = Array("input", "shortInput", "longInput", "intInput", - "floatInput", "decimalInput") - for (col <- cols) { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array(col)) + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoderEstimator() + .setInputCols(Array("input")) .setOutputCols(Array("output")) .setDropLast(false) - val model = encoder.fit(dfWithTypes) - val encoded = model.transform(dfWithTypes) - - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) - } + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } @@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === false) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output1", "output2")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).show - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "encoded") + } test("Can't transform on negative input") { @@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Negative value: -1.0. Input can't be negative") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Negative value: -1.0. Input can't be negative", + firstResultCol = "encoded") } test("Keep on invalid values: dropLast = false") { @@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(false) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(true) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(df) model.setDropLast(false) - val encoded1 = model.transform(df) - encoded1.select("output", "expected1").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { + case Row(output: Vector, expected1: Vector) => + assert(output === expected1) } model.setDropLast(true) - val encoded2 = model.transform(df) - encoded2.select("output", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { + case Row(output: Vector, expected2: Vector) => + assert(output === expected2) } } @@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(trainingDF) model.setHandleInvalid("error") - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Double, Vector)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "output") model.setHandleInvalid("keep") - model.transform(testDF).collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } } test("Transforming on mismatched attributes") { @@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", testAttr.toMetadata())) - val err = intercept[Exception] { - model.transform(testDF).collect() - } - err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + testTransformerByInterceptingException[(Double)]( + testDF, + model, + expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", + firstResultCol = "encoded") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index c44c6813a94be..41b32b2ffa096 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ class OneHotEncoderSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -54,16 +54,19 @@ class OneHotEncoderSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("OneHotEncoder dropLast = true") { @@ -71,16 +74,19 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(2, Seq((0, 1.0)))), + (1, Vectors.sparse(2, Seq())), + (2, Vectors.sparse(2, Seq((1, 1.0)))), + (3, Vectors.sparse(2, Seq((0, 1.0)))), + (4, Vectors.sparse(2, Seq((0, 1.0)))), + (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("input column with ML attribute") { @@ -90,20 +96,22 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("size") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } + test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) + val rows = encoder.transform(df).select("encoded").collect() + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) assert(group.size === 2) assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) @@ -119,29 +127,41 @@ class OneHotEncoderSuite test("OneHotEncoder with varying types") { val df = stringIndexed() - val dfWithTypes = df - .withColumn("shortLabel", df("labelIndex").cast(ShortType)) - .withColumn("longLabel", df("labelIndex").cast(LongType)) - .withColumn("intLabel", df("labelIndex").cast(IntegerType)) - .withColumn("floatLabel", df("labelIndex").cast(FloatType)) - .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) - val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", - "floatLabel", "decimalLabel") - for (col <- cols) { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = df.join(expected, "id") + + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = withExpected.select(col("labelIndex") + .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected")) val encoder = new OneHotEncoder() - .setInputCol(col) + .setInputCol("labelIndex") .setOutputCol("labelVec") .setDropLast(false) - val encoded = encoder.transform(dfWithTypes) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + + testTransformer(dfWithTypes, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 3067a52a4df76..531b1d7c4d9f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PCASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pcaModel = pca.fit(df) MLTestingUtils.checkCopyAndUids(pca, pcaModel) - - pcaModel.transform(df).select("pca_features", "expected").collect().foreach { - case Row(x: Vector, y: Vector) => - assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") { + case Row(result: Vector, expected: Vector) => + assert(result ~== expected absTol 1e-5, + "Transformed vector is different with expected vector.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index e4b0ddf98bfad..0be7aa6c83f29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.exceptions.TestFailedException - -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PolynomialExpansionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,6 +55,18 @@ class PolynomialExpansionSuite -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), Vectors.sparse(19, Array.empty, Array.empty)) + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after polynomial expansion.") + } + + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.") + } + test("Polynomial expansion with default parameter") { val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") @@ -67,13 +74,10 @@ class PolynomialExpansionSuite .setInputCol("features") .setOutputCol("polyFeatures") - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -85,13 +89,10 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(3) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -103,11 +104,9 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(1) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { case Row(expanded: Vector, expected: Vector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + assertValues(expanded, expected) } } @@ -133,12 +132,13 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") for (i <- Seq(10, 11)) { - val transformed = t.setDegree(i) - .transform(df) - .select(s"expectedPoly${i}size", "polyFeatures") - .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } - - assert(transformed.collect.forall(identity)) + testTransformer[(Vector, Int, Int)]( + df, + t.setDegree(i), + s"expectedPoly${i}size", + "polyFeatures") { case Row(size: Int, expected: Vector) => + assert(size === expected.size) + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6c363799dd300..b009038bbd833 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ -import org.apache.spark.sql.functions.udf -class QuantileDiscretizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,19 +36,19 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + val model = discretizer.fit(df) - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + val relativeError = discretizer.getRelativeError + val numGoodBuckets = result.groupBy("result").count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") } test("Test on data with high proportion of duplicated values") { @@ -67,11 +63,14 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") + } } test("Test transform on data with NaN value") { @@ -90,17 +89,20 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData.toSeq.toDF("input") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result") } List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result", "expected").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double)](dataFrame, model, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -119,14 +121,17 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(5) - val result = discretizer.fit(trainDF).transform(testDF) - val firstBucketSize = result.filter(result("result") === 0.0).count - val lastBucketSize = result.filter(result("result") === 4.0).count + val model = discretizer.fit(trainDF) + testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(firstBucketSize === 30L, - s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") - assert(lastBucketSize === 31L, - s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + } } test("read/write") { @@ -167,21 +172,24 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) - } - - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") - - val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + val relativeError = discretizer.getRelativeError + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result + .groupBy("result" + i) + .count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}") + .count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } } } @@ -198,12 +206,16 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets == expectedNumBucket, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBucket but found ($observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } } } @@ -226,9 +238,12 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double, Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result1") } List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { @@ -237,8 +252,14 @@ class QuantileDiscretizerSuite val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { case (((a, b), c), d) => (a, b, c, d) }.toSeq.toDF("input1", "input2", "expected1", "expected2") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result1", "expected1", "result2", "expected2").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double, Double, Double)]( + dataFrame, + model, + "result1", + "expected1", + "result2", + "expected2") { case Row(x: Double, y: Double, z: Double, w: Double) => assert(x === y && w === z) } @@ -270,9 +291,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(numBucketsArray) - discretizer.fit(df).transform(df). - select("result1", "expected1", "result2", "expected2", "result3", "expected3") - .collect().foreach { + val model = discretizer.fit(df) + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df, + model, + "result1", + "expected1", + "result2", + "expected2", + "result3", + "expected3") { case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => assert(r1 === e1, s"The result value is not correct after bucketing. Expected $e1 but found $r1") @@ -324,20 +352,16 @@ class QuantileDiscretizerSuite .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) .fit(df) - val resultForMultiCols = plForMultiCols.transform(df) - .select("result1", "result2", "result3") - .collect() - - val resultForSingleCol = plForSingleCol.transform(df) - .select("result1", "result2", "result3") - .collect() + val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect() - resultForSingleCol.zip(resultForMultiCols).foreach { - case (rowForSingle, rowForMultiCols) => - assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && - rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && - rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) - } + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + plForMultiCols, + "result1", + "result2", + "result3") { rows => + assert(rows === expected) + } } test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + @@ -364,18 +388,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(Array(10, 10, 10)) - val result1 = discretizerSingleNumBuckets.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - val result2 = discretizerNumBucketsArray.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - - result1.zip(result2).foreach { - case (row1, row2) => - assert(row1.getDouble(0) == row2.getDouble(0) && - row1.getDouble(1) == row2.getDouble(1) && - row1.getDouble(2) == row2.getDouble(2)) + val model = discretizerSingleNumBuckets.fit(df) + val expected = model.transform(df).select("result1", "result2", "result3").collect() + + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + discretizerNumBucketsArray.fit(df), + "result1", + "result2", + "result3") { rows => + assert(rows === expected) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index bfe38d32dd77d..27d570f0b68ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -32,10 +31,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { def testRFormulaTransform[A: Encoder]( dataframe: DataFrame, formulaModel: RFormulaModel, - expected: DataFrame): Unit = { + expected: DataFrame, + expectedAttributes: AttributeGroup*): Unit = { + val resultSchema = formulaModel.transformSchema(dataframe.schema) + assert(resultSchema.json === expected.schema.json) + assert(resultSchema === expected.schema) val (first +: rest) = expected.schema.fieldNames.toSeq val expectedRows = expected.collect() testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows.head.schema.toString() == resultSchema.toString()) + for (expectedAttributeGroup <- expectedAttributes) { + val attributeGroup = + AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name)) + assert(attributeGroup === expectedAttributeGroup) + } assert(rows === expectedRows) } } @@ -49,15 +58,10 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) MLTestingUtils.checkCopyAndUids(formula, model) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) ).toDF("id", "v1", "v2", "features", "label") - // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -73,9 +77,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "y", "features") val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == model.transform(original).schema.toString) + testRFormulaTransform[(Int, Double)](original, model, expected) } test("label column already exists but forceIndexLabel was set with true") { @@ -93,9 +101,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[IllegalArgumentException] { model.transformSchema(original.schema) } - intercept[IllegalArgumentException] { - model.transform(original) - } + testTransformerByInterceptingException[(Int, Boolean)]( + original, + model, + "Label column already exists and is not of type NumericType.", + "x") } test("allow missing label column for test datasets") { @@ -105,21 +115,22 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == model.transform(original).schema.toString) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "_not_y", "features") + testRFormulaTransform[(Int, Double)](original, model, expected) } test("allow empty label") { val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -128,15 +139,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -175,9 +183,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { val model = formula.setStringIndexerOrderType(orderType).fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } @@ -218,9 +223,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ).toDF("id", "a", "b", "features", "label") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -254,19 +256,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model1 = formula1.fit(original) - val result1 = model1.transform(original) - val resultSchema1 = model1.transformSchema(original.schema) - // Note the column order is different between R and Spark. - val expected1 = Seq( - (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), - (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), - (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), - (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) - ).toDF("id", "a", "b", "c", "features", "label") - assert(result1.schema.toString == resultSchema1.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) - - val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( "features", Array[Attribute]( @@ -275,14 +264,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new BinaryAttribute(Some("a_bar"), Some(3)), new BinaryAttribute(Some("b_zz"), Some(4)), new NumericAttribute(Some("c"), Some(5)))) - assert(attrs1 === expectedAttrs1) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1) // There is no impact for string terms interaction. val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model2 = formula2.fit(original) - val result2 = model2.transform(original) - val resultSchema2 = model2.transformSchema(original.schema) // Note the column order is different between R and Spark. val expected2 = Seq( (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), @@ -290,10 +285,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") - assert(result2.schema.toString == resultSchema2.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) - - val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( "features", Array[Attribute]( @@ -304,7 +295,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(5)), new NumericAttribute(Some("a_bar:b_zq"), Some(6)), new NumericAttribute(Some("c"), Some(7)))) - assert(attrs2 === expectedAttrs2) + + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2) } test("index string label") { @@ -313,13 +305,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) + val attr = NominalAttribute.defaultAttr val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") - // assert(result.schema.toString == resultSchema.toString) + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(String, String, Int)](original, model, expected) } @@ -329,13 +322,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val expected = spark.createDataFrame( - Seq( + val attr = NominalAttribute.defaultAttr + val expected = Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) - ).toDF("id", "a", "b", "features", "label") + .toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Double, String, Int)](original, model, expected) } @@ -344,15 +338,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + .toDF("id", "a", "b", "features", "label") val expectedAttrs = new AttributeGroup( "features", Array( new BinaryAttribute(Some("a_bar"), Some(1)), new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) + } test("vector attribute generation") { @@ -360,14 +359,19 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) .toDF("id", "vec") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val attrs = new AttributeGroup("vec", 2) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)) + .toDF("id", "vec", "features", "label") + .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec_0"), Some(1)), new NumericAttribute(Some("vec_1"), Some(2)))) - assert(attrs === expectedAttrs) + + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("vector attribute generation with unnamed input attrs") { @@ -381,31 +385,31 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0) + ).toDF("id", "vec2", "features", "label") + .select($"id", $"vec2".as("vec2", metadata), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec2_0"), Some(1)), new NumericAttribute(Some("vec2_1"), Some(2)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs) } test("factor numeric interaction") { @@ -414,7 +418,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -423,15 +426,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - testRFormulaTransform[(Int, String, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("a_baz:b"), Some(1)), new NumericAttribute(Some("a_bar:b"), Some(2)), new NumericAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) } test("factor factor interaction") { @@ -439,14 +440,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") testRFormulaTransform[(Int, String, String)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( @@ -454,7 +453,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(2)), new NumericAttribute(Some("a_foo:b_zq"), Some(3)), new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, String)](original, model, expected, expectedAttrs) } test("read/write: RFormula") { @@ -517,9 +516,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen features. val formula1 = new RFormula().setFormula("id ~ a + b") - intercept[SparkException] { - formula1.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula1.fit(df1), + "Unseen label:", + "features") val model1 = formula1.setHandleInvalid("skip").fit(df1) val model2 = formula1.setHandleInvalid("keep").fit(df1) @@ -538,21 +539,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") - intercept[SparkException] { - formula2.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula2.fit(df1), + "Unseen label:", + "label") + val model3 = formula2.setHandleInvalid("skip").fit(df1) val model4 = formula2.setHandleInvalid("keep").fit(df1) + val attr = NominalAttribute.defaultAttr val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + val expected4 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0), (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Int, String, String)](df2, model3, expected3) testRFormulaTransform[(Int, String, String)](df2, model4, expected4) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 673a146e619f2..cf09418d8e0a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.storage.StorageLevel -class SQLTransformerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class SQLTransformerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -37,14 +34,22 @@ class SQLTransformerSuite val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") - val result = sqlTrans.transform(original) - val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) - assert(original.sparkSession.catalog.listTables().count() == 0) + val resultSchema = sqlTrans.transformSchema(original.schema) + testTransformerByGlobalCheckFunc[(Int, Double, Double)]( + original, + sqlTrans, + "id", + "v1", + "v2", + "v3", + "v4") { rows => + assert(rows.head.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(rows == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) + } } test("read/write") { @@ -63,13 +68,13 @@ class SQLTransformerSuite } test("SPARK-22538: SQLTransformer should not unpersist given dataset") { - val df = spark.range(10) + val df = spark.range(10).toDF() df.cache() df.count() assert(df.storageLevel != StorageLevel.NONE) - new SQLTransformer() + val sqlTrans = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") - .transform(df) + testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => } assert(df.storageLevel != StorageLevel.NONE) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 350ba44baa1eb..c5c49d67194e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class StandardScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext ) } - def assertResult(df: DataFrame): Unit = { - df.select("standardized_features", "expected").collect().foreach { - case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, - "The vector value is not correct after standardization.") - } + def assertResult: Row => Unit = { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") } test("params") { @@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext val standardScaler0 = standardScalerEst0.fit(df0) MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) - assertResult(standardScaler0.transform(df0)) + testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")( + assertResult) } test("Standardization with setter") { @@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithStd(false) .fit(df3) - assertResult(standardScaler1.transform(df1)) - assertResult(standardScaler2.transform(df2)) - assertResult(standardScaler3.transform(df3)) + testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")( + assertResult) } test("sparse data and withMean") { @@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithMean(true) .setWithStd(false) .fit(df) - assertResult(standardScaler.transform(df)) + testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")( + assertResult) } test("StandardScaler read/write") { @@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 5262b146b184e..21259a50916d2 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -17,28 +17,20 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - -object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("filtered", "expected") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} -class StopWordsRemoverSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { - import StopWordsRemoverSuite._ import testImplicits._ + def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = { + testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") { + case Row(tokens: Seq[_], wantedTokens: Seq[_]) => + assert(tokens === wantedTokens) + } + } + test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") @@ -151,9 +143,10 @@ class StopWordsRemoverSuite .setOutputCol(outputCol) val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) - val thrown = intercept[IllegalArgumentException] { - testStopWordsRemover(remover, dataSet) - } - assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + testTransformerByInterceptingException[(Array[String], Array[String])]( + dataSet, + remover, + s"requirement failed: Column $outputCol already exists.", + "expected") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 775a04d3df050..df24367177011 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -class StringIndexerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StringIndexerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -46,19 +43,23 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val indexerModel = indexer.fit(df) - MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - - val transformed = indexerModel.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq( + (0, 0.0), + (1, 2.0), + (2, 1.0), + (3, 0.0), + (4, 0.0), + (5, 1.0) + ).toDF("id", "labelIndex") + + testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(rows.seq === expected.collect().toSeq) + } } test("StringIndexerUnseen") { @@ -70,36 +71,38 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + // Verify we throw by default with unseen values - intercept[SparkException] { - indexer.transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer, + "Unseen label:", + "labelIndex") - indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformedSkip = indexer.transform(df2) - val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + indexer.setHandleInvalid("skip") + + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows.seq === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - // Verify that we keep the unseen records - val transformedKeep = indexer.transform(df2) - val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 - val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF() + + // Verify that we keep the unseen records + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexer with a numeric input column") { @@ -109,16 +112,14 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + assert(rows === expected.collect().toSeq) + } } test("StringIndexer with NULLs") { @@ -133,37 +134,36 @@ class StringIndexerSuite withClue("StringIndexer should throw error when setHandleInvalid=error " + "when given NULL values") { - intercept[SparkException] { - indexer.setHandleInvalid("error") - indexer.fit(df).transform(df2).collect() - } + indexer.setHandleInvalid("error") + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer.fit(df), + "StringIndexer encountered NULL value.", + "labelIndex") } indexer.setHandleInvalid("skip") - val transformedSkip = indexer.fit(df).transform(df2) - val attrSkip = Attribute - .fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + val modelSkip = indexer.fit(df) // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows => + val attrSkip = + Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - val transformedKeep = indexer.fit(df).transform(df2) - val attrKeep = Attribute - .fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0, null -> 2 - val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF() + val modelKeep = indexer.fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows => + val attrKeep = Attribute + .fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexerModel should keep silent if the input column does not exist.") { @@ -171,7 +171,9 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() - assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) + testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows => + assert(rows.toSet === df.collect().toSet) + } } test("StringIndexerModel can't overwrite output column") { @@ -188,9 +190,12 @@ class StringIndexerSuite .setOutputCol("indexedInput") .fit(df) - intercept[IllegalArgumentException] { - indexer.setOutputCol("output").transform(df) - } + testTransformerByInterceptingException[(Int, String)]( + df, + indexer.setOutputCol("output"), + "Output column output already exists.", + "labelIndex") + } test("StringIndexer read/write") { @@ -223,7 +228,8 @@ class StringIndexerSuite .setInputCol("index") .setOutputCol("actual") .setLabels(labels) - idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -234,7 +240,8 @@ class StringIndexerSuite val idxToStr1 = new IndexToString() .setInputCol("indexWithAttr") .setOutputCol("actual") - idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -252,9 +259,10 @@ class StringIndexerSuite .setInputCol("labelIndex") .setOutputCol("sameLabel") .setLabels(indexer.labels) - idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { - case Row(a: String, b: String) => - assert(a === b) + + testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") { + case Row(sameLabel, label) => + assert(sameLabel === label) } } @@ -286,10 +294,11 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attrs = - NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) - assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows => + val attrs = + NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } test("StringIndexer order types") { @@ -299,18 +308,17 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") - val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), - Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), - Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), - Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { - val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet - assert(output === expected(idx)) + val model = indexer.setStringOrderType(orderType).fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows => + assert(rows === expected(idx).toDF().collect().toSeq) + } idx += 1 } } @@ -328,7 +336,11 @@ class StringIndexerSuite .setOutputCol("CITYIndexed") .fit(dfNoBristol) - val dfWithIndex = model.transform(dfNoBristol) - assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + testTransformerByGlobalCheckFunc[(String, String, String)]( + dfNoBristol, + model, + "CITYIndexed") { rows => + assert(rows.toList.count(_.getDouble(0) == 1.0) === 1) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index c895659a2d8be..be59b0af2c78e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class TokenizerSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) @@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } -class RegexTokenizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.RegexTokenizerSuite._ import testImplicits._ + def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = { + testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } + test("params") { ParamsSuite.checkParams(new RegexTokenizer) } @@ -105,14 +108,3 @@ class RegexTokenizerSuite } } -object RegexTokenizerSuite extends SparkFunSuite { - - def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("tokens", "wantedTokens") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 69a7b75e32eb7..e5675e31bbecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest with Logging { +class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { import testImplicits._ import VectorIndexerSuite.FeatureData @@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(vectorIndexer, model) - model.transform(densePoints1) // should work - model.transform(sparsePoints1) // should work + testTransformer[FeatureData](densePoints1, model, "indexed") { _ => } + testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => } + // If the data is local Dataset, it throws AssertionError directly. - intercept[AssertionError] { - model.transform(densePoints2).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called on " + + "vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2, + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } // If the data is distributed Dataset, it throws SparkException // which is the wrapper of AssertionError. - intercept[SparkException] { - model.transform(densePoints2.repartition(2)).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called " + + "on vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2.repartition(2), + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } intercept[SparkException] { vectorIndexer.fit(badPoints) @@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val categoryMaps = model.categoryMaps // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) - val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) - val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) - assert(featureAttrs.name === "indexed") - assert(featureAttrs.attributes.get.length === model.numFeatures) - categoricalFeatures.foreach { feature: Int => - val origValueSet = collectedData.map(_(feature)).toSet - val targetValueIndexSet = Range(0, origValueSet.size).toSet - val catMap = categoryMaps(feature) - assert(catMap.keys.toSet === origValueSet) // Correct categories - assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices - if (origValueSet.contains(0.0)) { - assert(catMap(0.0) === 0) // value 0 gets index 0 - } - // Check transformed data - assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) - // Check metadata - val featureAttr = featureAttrs(feature) - assert(featureAttr.index.get === feature) - featureAttr match { - case attr: BinaryAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - case attr: NominalAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - assert(attr.isOrdinal.get === false) - case _ => - throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed") + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } - } - // Check numerical feature metadata. - Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) - .foreach { feature: Int => - val featureAttr = featureAttrs(feature) - featureAttr match { - case attr: NumericAttribute => - assert(featureAttr.index.get === feature) - case _ => - throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } } } catch { @@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext (sparsePoints1, sparsePoints1TestInvalid))) { val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") val model = vectorIndexer.fit(points) - intercept[SparkException] { - model.transform(pointsTestInvalid).collect() - } + testTransformerByInterceptingException[FeatureData]( + pointsTestInvalid, + model, + "VectorIndexer encountered invalid value", + "indexed") val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") val model1 = vectorIndexer1.fit(points) - val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) - assert(transformed1 === invalidTransformed1) - + val expected = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, 1.0), + Vectors.dense(1.0, 3.0, 2.0)) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } + testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") val model2 = vectorIndexer2.fit(points) - val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - assert(invalidTransformed2 === transformed1 ++ Array( - Vectors.dense(2.0, 2.0, 0.0), - Vectors.dense(0.0, 4.0, 2.0), - Vectors.dense(1.0, 3.0, 3.0)) - ) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows => + assert(rows.map(_(0)) == expected ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0))) + } } } @@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = - model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() - points.zip(indexedPoints).foreach { - case (orig: SparseVector, indexed: SparseVector) => - assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + points.zip(rows.map(_(0))).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } } } checkSparsity(sparsePoints1, maxCategories = 2) @@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. - val indexedPoints = model.transform(densePoints1WithMeta) - val transAttributes: Array[Attribute] = - AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get - featureAttributes.zip(transAttributes).foreach { case (orig, trans) => - assert(orig.name === trans.name) - (orig, trans) match { - case (orig: NumericAttribute, trans: NumericAttribute) => - assert(orig.max.nonEmpty && orig.max === trans.max) - case _ => + testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows => + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => // do nothing // TODO: Once input features marked as categorical are handled correctly, check that here. + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala index f6c9a76599fae..d89d10b320d84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.StreamTest class VectorSizeHintSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,16 +38,23 @@ class VectorSizeHintSuite val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue") val noSizeTransformer = new VectorSizeHint().setInputCol("vector") - intercept[NoSuchElementException] (noSizeTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noSizeTransformer, + "Failed to find a default value for size", + "vector") intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema)) val noInputColTransformer = new VectorSizeHint().setSize(2) - intercept[NoSuchElementException] (noInputColTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noInputColTransformer, + "Failed to find a default value for inputCol", + "vector") intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema)) } test("Adding size to column of vectors.") { - val size = 3 val vectorColName = "vector" val denseVector = Vectors.dense(1, 2, 3) @@ -66,12 +71,15 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrame) - assert( - AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size, - "Transformer did not add expected size data.") - val numRows = withSize.collect().length - assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) { + rows => { + assert( + AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size, + "Transformer did not add expected size data.") + val numRows = rows.length + assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + } + } } } @@ -93,14 +101,16 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrameWithMetadata) - - val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName)) - assert(newGroup.size === size, "Column has incorrect size metadata.") - assert( - newGroup.attributes.get === group.attributes.get, - "VectorSizeHint did not preserve attributes.") - withSize.collect + testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + vectorColName) { rows => + val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName)) + assert(newGroup.size === size, "Column has incorrect size metadata.") + assert( + newGroup.attributes.get === group.attributes.get, + "VectorSizeHint did not preserve attributes.") + } } } @@ -120,7 +130,11 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata)) + testTransformerByInterceptingException[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + "Trying to set size of vectors in `vector` to 4 but size already set to 3.", + vectorColName) } } @@ -136,18 +150,36 @@ class VectorSizeHintSuite .setHandleInvalid("error") .setSize(3) - intercept[SparkException](sizeHint.transform(dataWithNull).collect()) - intercept[SparkException](sizeHint.transform(dataWithShort).collect()) + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithNull, + sizeHint, + "Got null vector in VectorSizeHint", + "vector") + + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithShort, + sizeHint, + "VectorSizeHint Expecting a vector of size 3 but got 1", + "vector") sizeHint.setHandleInvalid("skip") - assert(sizeHint.transform(dataWithNull).count() === 1) - assert(sizeHint.transform(dataWithShort).count() === 1) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 1) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 1) + } sizeHint.setHandleInvalid("optimistic") - assert(sizeHint.transform(dataWithNull).count() === 2) - assert(sizeHint.transform(dataWithShort).count() === 2) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 2) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 2) + } } + test("read/write") { val sizeHint = new VectorSizeHint() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 1746ce53107c4..3d90f9d9ac764 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StructField, StructType} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { val slicer = new VectorSlicer().setInputCol("feature") @@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") - def validateResults(df: DataFrame): Unit = { - df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + def validateResults(rows: Seq[Row]): Unit = { + rows.foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 === vec2) } - val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) - val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected")) assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => assert(a === b) @@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 10682ba176aca..b59c4e7967338 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class Word2VecSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul // These expectations are just magic values, characterizing the current // behavior. The test needs to be updated to be more general, see SPARK-11502 val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) - model.transform(docDF).select("result", "expected").collect().foreach { + testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } test("getVectors") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val spark = this.spark - import spark.implicits._ - val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -227,8 +213,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec works with input that is non-nullable (NGram)") { - val spark = this.spark - import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") @@ -243,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(ngramDF) // Just test that this transformation succeeds - model.transform(ngramDF).collect() + testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 17678aa611a48..795fd0e2ac0e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,9 +22,10 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.{PipelineModel, Transformer} +import org.apache.spark.ml.Transformer import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -62,8 +63,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val columnNames = dataframe.schema.fieldNames val stream = MemoryStream[A] - val streamDF = stream.toDS().toDF(columnNames: _*) - + val columnsWithMetadata = dataframe.schema.map { structField => + col(structField.name).as(structField.name, structField.metadata) + } + val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*) val data = dataframe.as[A].collect() val streamOutput = transformer.transform(streamDF) @@ -108,5 +111,29 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => otherResultCols: _*)(globalCheckFunction) testTransformerOnDF(dataframe, transformer, firstResultCol, otherResultCols: _*)(globalCheckFunction) + } + + def testTransformerByInterceptingException[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + expectedMessagePart : String, + firstResultCol: String) { + + def hasExpectedMessage(exception: Throwable): Boolean = + exception.getMessage.contains(expectedMessagePart) || + (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) + + withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { + val exceptionOnDf = intercept[Throwable] { + testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnDf)) + } + withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { + val exceptionOnStreamData = intercept[Throwable] { + testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnStreamData)) + } } } From 4f5bad615b47d743b8932aea1071652293981604 Mon Sep 17 00:00:00 2001 From: smallory Date: Thu, 15 Mar 2018 11:58:54 +0900 Subject: [PATCH 162/593] [SPARK-23642][DOCS] AccumulatorV2 subclass isZero scaladoc fix Added/corrected scaladoc for isZero on the DoubleAccumulator, CollectionAccumulator, and LongAccumulator subclasses of AccumulatorV2, particularly noting where there are requirements in addition to having a value of zero in order to return true. ## What changes were proposed in this pull request? Three scaladoc comments are updated in AccumulatorV2.scala No changes outside of comment blocks were made. ## How was this patch tested? Running "sbt unidoc", fixing style errors found, and reviewing the resulting local scaladoc in firefox. Author: smallory Closes #20790 from smallory/patch-1. --- .../main/scala/org/apache/spark/util/AccumulatorV2.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index f4a736d6d439a..0f84ea9752cf5 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -290,7 +290,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private var _count = 0L /** - * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + * * @since 2.0.0 */ override def isZero: Boolean = _sum == 0L && _count == 0 @@ -368,6 +369,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private var _sum = 0.0 private var _count = 0L + /** + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + */ override def isZero: Boolean = _sum == 0.0 && _count == 0 override def copy(): DoubleAccumulator = { @@ -441,6 +445,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) + /** + * Returns false if this accumulator instance has any values in it. + */ override def isZero: Boolean = _list.isEmpty override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator From 7c3e8995f18a1fb57c1f2c1b98a1d47590e28f38 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 15 Mar 2018 00:04:28 -0700 Subject: [PATCH 163/593] [SPARK-23533][SS] Add support for changing ContinuousDataReader's startOffset ## What changes were proposed in this pull request? As discussion in #20675, we need add a new interface `ContinuousDataReaderFactory` to support the requirements of setting start offset in Continuous Processing. ## How was this patch tested? Existing UT. Author: Yuanjian Li Closes #20689 from xuanyuanking/SPARK-23533. --- .../sql/kafka010/KafkaContinuousReader.scala | 11 +++++- .../reader/ContinuousDataReaderFactory.java | 35 +++++++++++++++++++ .../ContinuousRateStreamSource.scala | 15 +++++++- 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index ecd1170321f3f..6e56b0a72d671 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -164,7 +164,16 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] + require(kafkaOffset.topicPartition == topicPartition, + s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") + new KafkaContinuousDataReader( + topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } + override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java new file mode 100644 index 0000000000000..a61697649c43e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; + +/** + * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can + * implement this interface to provide creating {@link DataReader} with particular offset. + */ +@InterfaceStability.Evolving +public interface ContinuousDataReaderFactory extends DataReaderFactory { + /** + * Create a DataReader with particular offset as its startOffset. + * + * @param offset offset want to set as the DataReader's startOffset. + */ + DataReader createDataReaderWithOffset(PartitionOffset offset); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b63d8d3e20650..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -106,7 +106,20 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends DataReaderFactory[Row] { + extends ContinuousDataReaderFactory[Row] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousDataReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) + } + override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) From 56e8f48a43eb51e8582db2461a585b13a771a00a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Mar 2018 10:55:33 -0700 Subject: [PATCH 164/593] [SPARK-23695][PYTHON] Fix the error message for Kinesis streaming tests ## What changes were proposed in this pull request? This PR proposes to fix the error message for Kinesis in PySpark when its jar is missing but explicitly enabled. ```bash ENABLE_KINESIS_TESTS=1 SPARK_TESTING=1 bin/pyspark pyspark.streaming.tests ``` Before: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1572, in % kinesis_asl_assembly_dir) + NameError: name 'kinesis_asl_assembly_dir' is not defined ``` After: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1576, in "You need to build Spark with 'build/sbt -Pkinesis-asl " Exception: Failed to find Spark Streaming Kinesis assembly jar in /.../spark/external/kinesis-asl-assembly. You need to build Spark with 'build/sbt -Pkinesis-asl assembly/package streaming-kinesis-asl-assembly/assembly'or 'build/mvn -Pkinesis-asl package' before running this test. ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20834 from HyukjinKwon/minor-variable. --- python/pyspark/streaming/tests.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 71f8101e34c50..7dde7c0928c08 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1503,10 +1503,13 @@ def search_flume_assembly_jar(): return jars[0] -def search_kinesis_asl_assembly_jar(): +def _kinesis_asl_assembly_dir(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") + + +def search_kinesis_asl_assembly_jar(): + jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") if not jars: return None elif len(jars) > 1: @@ -1569,7 +1572,7 @@ def search_kinesis_asl_assembly_jar(): else: raise Exception( ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % kinesis_asl_assembly_dir) + + % _kinesis_asl_assembly_dir()) + "You need to build Spark with 'build/sbt -Pkinesis-asl " "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") From 15c3c983008557165cc91713ddaf2dbd6d5a506c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Mar 2018 19:54:58 +0100 Subject: [PATCH 165/593] [HOT-FIX] Fix SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88263/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88260/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88257/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88224/testReport These tests all failed: ``` org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:157) at org.apache.spark.memory.MemoryConsumer.allocateArray(MemoryConsumer.java:98) at org.apache.spark.unsafe.map.BytesToBytesMap.allocate(BytesToBytesMap.java:787) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:204) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:219) ... ``` This PR ignore this test. ## How was this patch tested? N/A Author: Yuming Wang Closes #20835 from wangyum/SPARK-23598. --- .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4b40e4ef7571c..9180a22c260f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -310,7 +310,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + ignore("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") for (i <- 0 until 70) { From 7618896e855579f111dd92cd76794a5672a087e5 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 15 Mar 2018 17:04:39 -0700 Subject: [PATCH 166/593] [SPARK-23658][LAUNCHER] InProcessAppHandle uses the wrong class in getLogger ## What changes were proposed in this pull request? Changed `Logger` in `InProcessAppHandle` to use `InProcessAppHandle` instead of `ChildProcAppHandle` Author: Sahil Takiar Closes #20815 from sahilTakiar/master. --- .../main/java/org/apache/spark/launcher/InProcessAppHandle.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 4b740d3fad20e..15fbca0facef2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -25,7 +25,7 @@ class InProcessAppHandle extends AbstractAppHandle { private static final String THREAD_NAME_FMT = "spark-app-%d: '%s'"; - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(InProcessAppHandle.class.getName()); private static final AtomicLong THREAD_IDS = new AtomicLong(); // Avoid really long thread names. From 18f8575e0166c6997569358d45bdae2cf45bf624 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Mar 2018 17:12:01 -0700 Subject: [PATCH 167/593] [SPARK-23671][CORE] Fix condition to enable the SHS thread pool. Author: Marcelo Vanzin Closes #20814 from vanzin/SPARK-23671. --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f9d0b5ee4e23e..ace6d9e00c838 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -173,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (Utils.isTesting) { + if (!Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() From 3675af7247e841e9a689666dc20891ba55c612b3 Mon Sep 17 00:00:00 2001 From: Ye Zhou Date: Thu, 15 Mar 2018 17:15:53 -0700 Subject: [PATCH 168/593] [SPARK-23608][CORE][WEBUI] Add synchronization in SHS between attachSparkUI and detachSparkUI functions to avoid concurrent modification issue to Jetty Handlers Jetty handlers are dynamically attached/detached while SHS is running. But the attach and detach operations might be taking place at the same time due to the async in load/clear in Guava Cache. ## What changes were proposed in this pull request? Add synchronization between attachSparkUI and detachSparkUI in SHS. ## How was this patch tested? With this patch, the jetty handlers missing issue never happens again in our production cluster SHS. Author: Ye Zhou Closes #20744 from zhouyejoe/SPARK-23608. --- .../apache/spark/deploy/history/HistoryServer.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 0ec4afad0308c..611fa563a7cd9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -150,14 +150,18 @@ class HistoryServer( ui: SparkUI, completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") - ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) + handlers.synchronized { + ui.getHandlers.foreach(attachHandler) + addFilters(ui.getHandlers, conf) + } } /** Detach a reconstructed UI from this server. Only valid after bind(). */ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") - ui.getHandlers.foreach(detachHandler) + handlers.synchronized { + ui.getHandlers.foreach(detachHandler) + } provider.onUIDetached(appId, attemptId, ui) } From c2632edebd978716dbfa7874a2fc0a8f5a4a9951 Mon Sep 17 00:00:00 2001 From: myroslavlisniak Date: Thu, 15 Mar 2018 17:20:17 -0700 Subject: [PATCH 169/593] [SPARK-23670][SQL] Fix memory leak on SparkPlanGraphWrapper Clean up SparkPlanGraphWrapper objects from InMemoryStore together with cleaning up SQLExecutionUIData existing unit test was extended to check also SparkPlanGraphWrapper object count vanzin Author: myroslavlisniak Closes #20813 from myroslavlisniak/master. --- .../apache/spark/sql/execution/ui/SQLAppStatusListener.scala | 5 ++++- .../apache/spark/sql/execution/ui/SQLAppStatusStore.scala | 4 ++++ .../spark/sql/execution/ui/SQLAppStatusListenerSuite.scala | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 53fb9a0cc21cf..71e9f93c4566e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -334,7 +334,10 @@ class SQLAppStatusListener( val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) - toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } + toDelete.foreach { e => + kvstore.delete(e.getClass(), e.executionId) + kvstore.delete(classOf[SparkPlanGraphWrapper], e.executionId) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 9a76584717f42..241001a857c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -54,6 +54,10 @@ class SQLAppStatusStore( store.count(classOf[SQLExecutionUIData]) } + def planGraphCount(): Long = { + store.count(classOf[SparkPlanGraphWrapper]) + } + def executionMetrics(executionId: Long): Map[Long, String] = { def metricsFromStore(): Option[Map[Long, String]] = { val exec = store.read(classOf[SQLExecutionUIData], executionId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 85face3994fd4..f3f08839c1d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -611,6 +611,7 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { sc.listenerBus.waitUntilEmpty(10000) val statusStore = spark.sharedState.statusStore assert(statusStore.executionsCount() <= 50) + assert(statusStore.planGraphCount() <= 50) // No live data should be left behind after all executions end. assert(statusStore.listener.get.noLiveData()) } From ca83526de55f0f8784df58cc8b7c0a7cb0c96e23 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 16 Mar 2018 15:12:26 +0800 Subject: [PATCH 170/593] [SPARK-23644][CORE][UI] Use absolute path for REST call in SHS ## What changes were proposed in this pull request? SHS is using a relative path for the REST API call to get the list of the application is a relative path call. In case of the SHS being consumed through a proxy, it can be an issue if the path doesn't end with a "/". Therefore, we should use an absolute path for the REST call as it is done for all the other resources. ## How was this patch tested? manual tests Before the change: ![screen shot 2018-03-10 at 4 22 02 pm](https://user-images.githubusercontent.com/8821783/37244190-8ccf9d40-2485-11e8-8fa9-345bc81472fc.png) After the change: ![screen shot 2018-03-10 at 4 36 34 pm 1](https://user-images.githubusercontent.com/8821783/37244201-a1922810-2485-11e8-8856-eeab2bf5e180.png) Author: Marco Gaido Closes #20794 from mgaido91/SPARK-23644. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index f0b2a5a833a99..abc2ec0fa6531 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -113,7 +113,7 @@ $(document).ready(function() { status: (requestedIncomplete ? "running" : "completed") }; - $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { + $.getJSON(uiRoot + "/api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { @@ -151,7 +151,7 @@ $(document).ready(function() { "showCompletedColumns": !requestedIncomplete, } - $.get("static/historypage-template.html", function(template) { + $.get(uiRoot + "/static/historypage-template.html", function(template) { var sibling = historySummary.prev(); historySummary.detach(); var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data)); From c952000487ee003200221b3c4e25dcb06e359f0a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Mar 2018 16:22:03 +0800 Subject: [PATCH 171/593] [SPARK-23635][YARN] AM env variable should not overwrite same name env variable set through spark.executorEnv. ## What changes were proposed in this pull request? In the current Spark on YARN code, AM always will copy and overwrite its env variables to executors, so we cannot set different values for executors. To reproduce issue, user could start spark-shell like: ``` ./bin/spark-shell --master yarn-client --conf spark.executorEnv.SPARK_ABC=executor_val --conf spark.yarn.appMasterEnv.SPARK_ABC=am_val ``` Then check executor env variables by ``` sc.parallelize(1 to 1).flatMap \{ i => sys.env.toSeq }.collect.foreach(println) ``` We will always get `am_val` instead of `executor_val`. So we should not let AM to overwrite specifically set executor env variables. ## How was this patch tested? Added UT and tested in local cluster. Author: jerryshao Closes #20799 from jerryshao/SPARK-23635. --- .../spark/deploy/yarn/ExecutorRunnable.scala | 22 +++++++----- .../spark/deploy/yarn/YarnClusterSuite.scala | 36 +++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 3f4d236571ffd..ab08698035c98 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -220,12 +220,6 @@ private[yarn] class ExecutorRunnable( val env = new HashMap[String, String]() Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, @@ -233,6 +227,20 @@ private[yarn] class ExecutorRunnable( ) val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } + + sparkConf.getExecutorEnv.foreach { case (key, value) => + if (key == Environment.CLASSPATH.name()) { + // If the key of env variable is CLASSPATH, we assume it is a path and append it. + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } else { + // For other env variables, simply overwrite the value. + env(key) = value + } + } + // Add log urls container.foreach { c => sys.env.get("SPARK_USER").foreach { user => @@ -245,8 +253,6 @@ private[yarn] class ExecutorRunnable( } } - System.getenv().asScala.filterKeys(_.startsWith("SPARK")) - .foreach { case (k, v) => env(k) = v } env } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 33d400a5b1b2e..a129be7c06b53 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -225,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { finalState should be (SparkAppHandle.State.FAILED) } + test("executor env overwrite AM env in client mode") { + testExecutorEnv(true) + } + + test("executor env overwrite AM env in cluster mode") { + testExecutorEnv(false) + } + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), @@ -305,6 +313,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, executorResult, "OVERRIDDEN") } + private def testExecutorEnv(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass), + appArgs = Seq(result.getAbsolutePath), + extraConf = Map( + "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val", + "spark.executorEnv.TEST_ENV" -> "executor_val" + ) + ) + checkResult(finalState, result, "true") + } } private[spark] class SaveExecutorInfo extends SparkListener { @@ -526,3 +545,20 @@ private object SparkContextTimeoutApp { } } + +private object ExecutorEnvTestApp { + + def main(args: Array[String]): Unit = { + val status = args(0) + val sparkConf = new SparkConf() + val sc = new SparkContext(sparkConf) + val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap + val result = sparkConf.getExecutorEnv.forall { case (k, v) => + executorEnvs.get(k).contains(v) + } + + Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + sc.stop() + } + +} From 5414abca4fec6a68174c34d22d071c20027e959d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 16 Mar 2018 09:36:30 -0700 Subject: [PATCH 172/593] [SPARK-23553][TESTS] Tests should not assume the default value of `spark.sql.sources.default` ## What changes were proposed in this pull request? Currently, some tests have an assumption that `spark.sql.sources.default=parquet`. In fact, that is a correct assumption, but that assumption makes it difficult to test new data source format. This PR aims to - Improve test suites more robust and makes it easy to test new data sources in the future. - Test new native ORC data source with the full existing Apache Spark test coverage. As an example, the PR uses `spark.sql.sources.default=orc` during reviews. The value should be `parquet` when this PR is accepted. ## How was this patch tested? Pass the Jenkins with updated tests. Author: Dongjoon Hyun Closes #20705 from dongjoon-hyun/SPARK-23553. --- python/pyspark/sql/readwriter.py | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +-- .../columnar/InMemoryColumnarQuerySuite.scala | 5 +- .../sql/execution/command/DDLSuite.scala | 11 ++- .../ParquetPartitionDiscoverySuite.scala | 10 +++ .../sql/test/DataFrameReaderWriterSuite.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 72 +++++++++---------- .../PartitionProviderCompatibilitySuite.scala | 6 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 11 +-- .../sql/hive/execution/SQLQuerySuite.scala | 19 ++--- 11 files changed, 81 insertions(+), 71 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 803f561ece67b..facc16bc53108 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -147,8 +147,8 @@ def load(self, path=None, format=None, schema=None, **options): or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options - >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, - ... opt2=1, opt3='str') + >>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned', + ... opt1=True, opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8f14575c3325f..640affc10ee58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2150,7 +2150,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { - sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + val provider = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE tbl(i INT, j STRING) USING $provider") checkAnswer(sql("SELECT i, j FROM tbl"), Nil) Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") @@ -2474,9 +2475,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { withTempDir { dir => - val parquetDir = new File(dir, "parquet").getCanonicalPath - spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) - spark.read.parquet(parquetDir) + val dataDir = new File(dir, "data").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(dataDir) + spark.read.load(dataDir) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index dc1766fb9a785..26b63e8e8490f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -487,7 +487,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { - withSQLConf("spark.sql.cbo.enabled" -> "true") { + // This test case depends on the size of parquet in statistics. + withSQLConf( + SQLConf.CBO_ENABLED.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "parquet") { withTempPath { workDir => withTable("table1") { val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4041176262426..4df8fbfe1c0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -154,10 +154,15 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") val e = intercept[AnalysisException] { - Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + val format = if (spark.sessionState.conf.defaultDataSourceName.equalsIgnoreCase("json")) { + "orc" + } else { + "json" + } + Seq(5 -> "e").toDF("i", "j").write.mode("append").format(format).saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index edb3da904d10d..e887c9734a8b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -57,6 +57,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val timeZone = TimeZone.getDefault() val timeZoneId = timeZone.getID + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "parquet") + } + + protected override def afterAll(): Unit = { + spark.conf.unset(SQLConf.DEFAULT_DATA_SOURCE_NAME.key) + super.afterAll() + } + test("column type inference") { def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { assert(inferPartitionColumnValue(raw, true, timeZone) === literal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a707a88dfa670..14b1feb2adc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -562,7 +562,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - sql("CREATE TABLE same_name(id LONG) USING parquet") + val format = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE same_name(id LONG) USING $format") spark.range(10).createTempView("same_name") spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 859099a321bf7..d93215fefb810 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -591,7 +591,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (ArrayType)") { - withTable("arrayInParquet") { + withTable("array") { { val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") val expectedSchema = @@ -604,9 +604,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("arrayInParquet") + .saveAsTable("array") } { @@ -621,25 +620,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("arrayInParquet") + .insertInto("array") } (Tuple1(Seq(4, 5)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + .saveAsTable("array") // This one internally calls df2.insertInto. (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") + .saveAsTable("array") - sparkSession.catalog.refreshTable("arrayInParquet") + sparkSession.catalog.refreshTable("array") checkAnswer( - sql("SELECT a FROM arrayInParquet"), + sql("SELECT a FROM array"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -648,7 +646,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (MapType)") { - withTable("mapInParquet") { + withTable("map") { { val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") val expectedSchema = @@ -661,9 +659,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("mapInParquet") + .saveAsTable("map") } { @@ -678,27 +675,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("mapInParquet") + .insertInto("map") } (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + .saveAsTable("map") // This one internally calls df2.insertInto. (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") + .saveAsTable("map") - sparkSession.catalog.refreshTable("mapInParquet") + sparkSession.catalog.refreshTable("map") checkAnswer( - sql("SELECT a FROM mapInParquet"), + sql("SELECT a FROM map"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -852,52 +846,52 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv (from to to).map(i => i -> s"str$i").toDF("c1", "c2") } - withTable("insertParquet") { - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + withTable("t") { + createDF(0, 9).write.saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.saveAsTable("t") } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(20, 29).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") + createDF(30, 39).write.saveAsTable("t") } - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + createDF(40, 49).write.mode(SaveMode.Append).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (50 to 59).map(i => Row(i, s"str$i"))) - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (70 to 79).map(i => Row(i, s"str$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 9440a17677ebf..80afc9d8f44bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -37,11 +37,11 @@ class PartitionProviderCompatibilitySuite spark.range(5).selectExpr("id as fieldOne", "id as partCol").write .partitionBy("partCol") .mode("overwrite") - .parquet(dir.getAbsolutePath) + .save(dir.getAbsolutePath) spark.sql(s""" |create table $tableName (fieldOne long, partCol int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${dir.toURI}") |partitioned by (partCol)""".stripMargin) } @@ -358,7 +358,7 @@ class PartitionProviderCompatibilitySuite try { spark.sql(s""" |create table test (id long, P1 int, P2 int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${base.toURI}") |partitioned by (P1, P2)""".stripMargin) spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.toURI}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 54d3962a46b4d..1a86c604d5da3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -417,7 +417,7 @@ class PartitionedTablePerfStatsSuite import spark.implicits._ Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) HiveCatalogMetrics.reset() - spark.read.parquet(dir.getAbsolutePath) + spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 65be244418670..db76ec9d084cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1658,8 +1658,8 @@ class HiveDDLSuite Seq(5 -> "e").toDF("i", "j") .write.format("hive").mode("append").saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `HiveFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format `HiveFileFormat`.")) } } @@ -1709,11 +1709,12 @@ class HiveDDLSuite spark.sessionState.catalog.getTableMetadata(TableIdentifier(tblName)).schema.map(_.name) } + val provider = spark.sessionState.conf.defaultDataSourceName withTable("t", "t1", "t2", "t3", "t4", "t5", "t6") { - sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)") + sql(s"CREATE TABLE t(a int, b int, c int, d int) USING $provider PARTITIONED BY (d, b)") assert(getTableColumns("t") == Seq("a", "c", "d", "b")) - sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") + sql(s"CREATE TABLE t1 USING $provider PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") assert(getTableColumns("t1") == Seq("a", "c", "d", "b")) Seq((1, 1, 1, 1)).toDF("a", "b", "c", "d").write.partitionBy("d", "b").saveAsTable("t2") @@ -1723,7 +1724,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") + sql(s"CREATE TABLE t3 USING $provider LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index baabc4a3bca2c..73f83d593bbfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -516,24 +516,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS with default fileformat") { val table = "ctas1" val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" - withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { - withSQLConf("hive.default.fileformat" -> "textfile") { + Seq("orc", "parquet").foreach { dataSourceFormat => + withSQLConf( + SQLConf.CONVERT_CTAS.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> dataSourceFormat, + "hive.default.fileformat" -> "textfile") { withTable(table) { sql(ctas) - // We should use parquet here as that is the default datasource fileformat. The default - // datasource file format is controlled by `spark.sql.sources.default` configuration. + // The default datasource file format is controlled by `spark.sql.sources.default`. // This testcase verifies that setting `hive.default.fileformat` has no impact on // the target table's fileformat in case of CTAS. - assert(sessionState.conf.defaultDataSourceName === "parquet") - checkRelation(tableName = table, isDataSourceTable = true, format = "parquet") + checkRelation(tableName = table, isDataSourceTable = true, format = dataSourceFormat) } } - withSQLConf("spark.sql.sources.default" -> "orc") { - withTable(table) { - sql(ctas) - checkRelation(tableName = table, isDataSourceTable = true, format = "orc") - } - } } } From dffeac3691daa620446ae949c5b147518d128e08 Mon Sep 17 00:00:00 2001 From: Sebastian Arzt Date: Fri, 16 Mar 2018 12:25:58 -0500 Subject: [PATCH 173/593] [SPARK-18371][STREAMING] Spark Streaming backpressure generates batch with large number of records MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Omit rounding of backpressure rate. Effects: - no batch with large number of records is created when rate from PID estimator is one - the number of records per batch and partition is more fine-grained improving backpressure accuracy ## How was this patch tested? This was tested by running: - `mvn test -pl external/kafka-0-8` - `mvn test -pl external/kafka-0-10` - a streaming application which was suffering from the issue JasonMWhite The contribution is my original work and I license the work to the project under the project’s open source license Author: Sebastian Arzt Closes #17774 from arzt/kafka-back-pressure. --- .../kafka010/DirectKafkaInputDStream.scala | 6 +-- .../kafka010/DirectKafkaStreamSuite.scala | 48 +++++++++++++++++ .../kafka/DirectKafkaInputDStream.scala | 6 +-- .../kafka/DirectKafkaStreamSuite.scala | 51 +++++++++++++++++++ 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0fa3287f36db8..9cb2448fea0f4 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -138,17 +138,17 @@ private[spark] class DirectKafkaInputDStream[K, V]( lagPerPartition.map { case (tp, lag) => val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 453b5e5ab20d3..8524743ee2846 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -617,6 +617,54 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = getKafkaParams() + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimateRate = 1L + val fromOffsets = Map( + new TopicPartition(topic, 0) -> 0L, + new TopicPartition(topic, 1) -> 0L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 0L + ) + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { + currentOffsets = fromOffsets + override val rateController = Some(new ConstantRateController(id, null, estimateRate)) + } + } + + val offsets = Map[TopicPartition, Long]( + new TopicPartition(topic, 0) -> 0, + new TopicPartition(topic, 1) -> 100L, + new TopicPartition(topic, 2) -> 200L, + new TopicPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 1) -> 10L, + new TopicPartition(topic, 2) -> 20L, + new TopicPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d52c230eb7849..d6dd0744441e4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -104,17 +104,17 @@ class DirectKafkaInputDStream[ val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 06ef5bc3f8bd0..3fea6cfd910bf 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -456,6 +456,57 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimatedRate = 1L + val kafkaStream = withClue("Error creating direct stream") { + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val fromOffsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 0L, + TopicAndPartition(topic, 2) -> 0L, + TopicAndPartition(topic, 3) -> 0L + ) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, fromOffsets, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, null) { + override def getLatestRate() = estimatedRate + }) + } + } + + val offsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 100L, + TopicAndPartition(topic, 2) -> 200L, + TopicAndPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + TopicAndPartition(topic, 0) -> 1L, + TopicAndPartition(topic, 1) -> 10L, + TopicAndPartition(topic, 2) -> 20L, + TopicAndPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { From 88d8de9260edf6e9d5449ff7ef6e35d16051fc9f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 16 Mar 2018 18:28:16 +0100 Subject: [PATCH 174/593] [SPARK-23581][SQL] Add interpreted unsafe projection ## What changes were proposed in this pull request? We currently can only create unsafe rows using code generation. This is a problem for situations in which code generation fails. There is no fallback, and as a result we cannot execute the query. This PR adds an interpreted version of `UnsafeProjection`. The implementation is modeled after `InterpretedMutableProjection`. It stores the expression results in a `GenericInternalRow`, and it then uses a conversion function to convert the `GenericInternalRow` into an `UnsafeRow`. This PR does not implement the actual code generated to interpreted fallback logic. This will be done in a follow-up. ## How was this patch tested? I am piggybacking on exiting `UnsafeProjection` tests, and I have added an interpreted version for each of these. Author: Herman van Hovell Closes #20750 from hvanhovell/SPARK-23581. --- .../codegen/UnsafeArrayWriter.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 30 +- .../expressions/codegen/UnsafeWriter.java | 43 ++ .../sql/catalyst/expressions/Expression.scala | 26 ++ .../InterpretedUnsafeProjection.scala | 366 ++++++++++++++++++ .../MonotonicallyIncreasingID.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 19 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../expressions/randomExpressions.scala | 6 +- .../expressions/ComplexTypeSuite.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 20 +- .../expressions/ObjectExpressionsSuite.scala | 21 +- .../catalyst/expressions/ScalaUDFSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 56 +-- 14 files changed, 555 insertions(+), 74 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 791e8d80e6cba..82cd1b24607e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -30,7 +30,7 @@ * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. */ -public class UnsafeArrayWriter { +public final class UnsafeArrayWriter extends UnsafeWriter { private BufferHolder holder; @@ -83,7 +83,7 @@ private long getElementOffset(int ordinal, int elementSize) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, long currentCursor, int size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { assertIndexIsValid(ordinal); final long relativeOffset = currentCursor - startingOffset; final long offsetAndSize = (relativeOffset << 32) | (long)size; @@ -96,49 +96,31 @@ private void setNullBit(int ordinal) { BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); } - public void setNullBoolean(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); - } - - public void setNullByte(int ordinal) { + public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); } - public void setNullShort(int ordinal) { + public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); } - public void setNullInt(int ordinal) { + public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); } - public void setNullLong(int ordinal) { + public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); } - public void setNullFloat(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); - } - - public void setNullDouble(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); - } - - public void setNull(int ordinal) { setNullLong(ordinal); } + public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 5d9515c0725da..2620bbcfb87a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -38,7 +38,7 @@ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. */ -public class UnsafeRowWriter { +public final class UnsafeRowWriter extends UnsafeWriter { private final BufferHolder holder; // The offset of the global buffer where we start to write this row. @@ -93,18 +93,38 @@ public void setNullAt(int ordinal) { Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); } + @Override + public void setNull1Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull2Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull4Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull8Bytes(int ordinal) { + setNullAt(ordinal); + } + public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, long size) { + public void setOffsetAndSize(int ordinal, int size) { setOffsetAndSize(ordinal, holder.cursor, size); } - public void setOffsetAndSize(int ordinal, long currentCursor, long size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { final long relativeOffset = currentCursor - startingOffset; final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | size; + final long offsetAndSize = (relativeOffset << 32) | (long) size; Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); } @@ -174,7 +194,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { if (input == null || !input.changePrecision(precision, scale)) { BitSetMethods.set(holder.buffer, startingOffset, ordinal); // keep the offset for future update - setOffsetAndSize(ordinal, 0L); + setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); assert bytes.length <= 16; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java new file mode 100644 index 0000000000000..c94b5c7a367ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class for writing Unsafe* structures. + */ +public abstract class UnsafeWriter { + public abstract void setNull1Bytes(int ordinal); + public abstract void setNull2Bytes(int ordinal); + public abstract void setNull4Bytes(int ordinal); + public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); + public abstract void write(int ordinal, byte value); + public abstract void write(int ordinal, short value); + public abstract void write(int ordinal, int value); + public abstract void write(int ordinal, long value); + public abstract void write(int ordinal, float value); + public abstract void write(int ordinal, double value); + public abstract void write(int ordinal, Decimal input, int precision, int scale); + public abstract void write(int ordinal, UTF8String input); + public abstract void write(int ordinal, byte[] input); + public abstract void write(int ordinal, CalendarInterval input); + public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ed90b185865a0..d7f9e38915dd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -328,6 +328,32 @@ trait Nondeterministic extends Expression { protected def evalInternal(input: InternalRow): Any } +/** + * An expression that contains mutable state. A stateful expression is always non-deterministic + * because the results it produces during evaluation are not only dependent on the given input + * but also on its internal state. + * + * The state of the expressions is generally not exposed in the parameter list and this makes + * comparing stateful expressions problematic because similar stateful expressions (with the same + * parameter list) but with different internal state will be considered equal. This is especially + * problematic during tree transformations. In order to counter this the `fastEquals` method for + * stateful expressions only returns `true` for the same reference. + * + * A stateful expression should never be evaluated multiple times for a single row. This should + * only be a problem for interpreted execution. This can be prevented by creating fresh copies + * of the stateful expression before execution, these can be made using the `freshCopy` function. + */ +trait Stateful extends Nondeterministic { + /** + * Return a fresh uninitialized copy of the stateful expression. + */ + def freshCopy(): Stateful + + /** + * Only the same reference is considered equal. + */ + override def fastEquals(other: TreeNode[_]): Boolean = this eq other +} /** * A leaf expression, i.e. one without any child expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala new file mode 100644 index 0000000000000..0da5ece7e47fe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{UserDefinedType, _} +import org.apache.spark.unsafe.Platform + +/** + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer + * should copy the row if it is being buffered. This class is not thread safe. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { + import InterpretedUnsafeProjection._ + + /** Number of (top level) fields in the resulting row. */ + private[this] val numFields = expressions.length + + /** Array that expression results. */ + private[this] val values = new Array[Any](numFields) + + /** The row representing the expression results. */ + private[this] val intermediate = new GenericInternalRow(values) + + /** The row returned by the projection. */ + private[this] val result = new UnsafeRow(numFields) + + /** The buffer which holds the resulting row's backing data. */ + private[this] val holder = new BufferHolder(result, numFields * 32) + + /** The writer that writes the intermediate result to the result row. */ + private[this] val writer: InternalRow => Unit = { + val rowWriter = new UnsafeRowWriter(holder, numFields) + val baseWriter = generateStructWriter( + holder, + rowWriter, + expressions.map(e => StructField("", e.dataType, e.nullable))) + if (!expressions.exists(_.nullable)) { + // No nullable fields. The top-level null bit mask will always be zeroed out. + baseWriter + } else { + // Zero out the null bit mask before we write the row. + row => { + rowWriter.zeroOutNullBytes() + baseWriter(row) + } + } + } + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + override def apply(row: InternalRow): UnsafeRow = { + // Put the expression results in the intermediate row. + var i = 0 + while (i < numFields) { + values(i) = expressions(i).eval(row) + i += 1 + } + + // Write the intermediate row to an unsafe row. + holder.reset() + writer(intermediate) + result.setTotalSize(holder.totalSize()) + result + } +} + +/** + * Helper functions for creating an [[InterpretedUnsafeProjection]]. + */ +object InterpretedUnsafeProjection extends UnsafeProjectionCreator { + + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedUnsafeProjection(cleanedExpressions.toArray) + } + + /** + * Generate a struct writer function. The generated function writes an [[InternalRow]] to the + * given buffer using the given [[UnsafeRowWriter]]. + */ + private def generateStructWriter( + bufferHolder: BufferHolder, + rowWriter: UnsafeRowWriter, + fields: Array[StructField]): InternalRow => Unit = { + val numFields = fields.length + + // Create field writers. + val fieldWriters = fields.map { field => + generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + } + // Create basic writer. + row => { + var i = 0 + while (i < numFields) { + fieldWriters(i).apply(row, i) + i += 1 + } + } + } + + /** + * Generate a writer function for a struct field, array element, map key or map value. The + * generated function writes the element at an index in a [[SpecializedGetters]] object (row + * or array) to the given buffer using the given [[UnsafeWriter]]. + */ + private def generateFieldWriter( + bufferHolder: BufferHolder, + writer: UnsafeWriter, + dt: DataType, + nullable: Boolean): (SpecializedGetters, Int) => Unit = { + + // Create the the basic writer. + val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match { + case BooleanType => + (v, i) => writer.write(i, v.getBoolean(i)) + + case ByteType => + (v, i) => writer.write(i, v.getByte(i)) + + case ShortType => + (v, i) => writer.write(i, v.getShort(i)) + + case IntegerType | DateType => + (v, i) => writer.write(i, v.getInt(i)) + + case LongType | TimestampType => + (v, i) => writer.write(i, v.getLong(i)) + + case FloatType => + (v, i) => writer.write(i, v.getFloat(i)) + + case DoubleType => + (v, i) => writer.write(i, v.getDouble(i)) + + case DecimalType.Fixed(precision, scale) => + (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) + + case CalendarIntervalType => + (v, i) => writer.write(i, v.getInterval(i)) + + case BinaryType => + (v, i) => writer.write(i, v.getBinary(i)) + + case StringType => + (v, i) => writer.write(i, v.getUTF8String(i)) + + case StructType(fields) => + val numFields = fields.length + val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) + val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getStruct(i, fields.length) match { + case row: UnsafeRow => + writeUnsafeData( + bufferHolder, + row.getBaseObject, + row.getBaseOffset, + row.getSizeInBytes) + case row => + // Nested struct. We don't know where this will start because a row can be + // variable length, so we need to update the offsets and zero out the bit mask. + rowWriter.reset() + structWriter.apply(row) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case ArrayType(elementType, containsNull) => + val arrayWriter = new UnsafeArrayWriter + val elementSize = getElementSize(elementType) + val elementWriter = generateFieldWriter( + bufferHolder, + arrayWriter, + elementType, + containsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case MapType(keyType, valueType, valueContainsNull) => + val keyArrayWriter = new UnsafeArrayWriter + val keySize = getElementSize(keyType) + val keyWriter = generateFieldWriter( + bufferHolder, + keyArrayWriter, + keyType, + nullable = false) + val valueArrayWriter = new UnsafeArrayWriter + val valueSize = getElementSize(valueType) + val valueWriter = generateFieldWriter( + bufferHolder, + valueArrayWriter, + valueType, + valueContainsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getMap(i) match { + case map: UnsafeMapData => + writeUnsafeData( + bufferHolder, + map.getBaseObject, + map.getBaseOffset, + map.getSizeInBytes) + case map => + // preserve 8 bytes to write the key array numBytes later. + bufferHolder.grow(8) + bufferHolder.cursor += 8 + + // Write the keys and write the numBytes of key array into the first 8 bytes. + writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) + Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + + // Write the values. + writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case udt: UserDefinedType[_] => + generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + + case NullType => + (_, _) => {} + + case _ => + throw new SparkException(s"Unsupported data type $dt") + } + + // Always wrap the writer with a null safe version. + dt match { + case _: UserDefinedType[_] => + // The null wrapper depends on the sql type and not on the UDT. + unsafeWriter + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + // We can't call setNullAt() for DecimalType with precision larger than 18, we call write + // directly. We can use the unwrapped writer directly. + unsafeWriter + case BooleanType | ByteType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull1Bytes(i) + } + } + case ShortType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull2Bytes(i) + } + } + case IntegerType | DateType | FloatType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull4Bytes(i) + } + } + case _ => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull8Bytes(i) + } + } + } + } + + /** + * Get the number of bytes elements of a data type will occupy in the fixed part of an + * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an + * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length + * portion of the array object. Primitives take up to 8 bytes, depending on the size of the + * underlying data type. + */ + private def getElementSize(dataType: DataType): Int = dataType match { + case NullType | StringType | BinaryType | CalendarIntervalType | + _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 + case _ => dataType.defaultSize + } + + /** + * Write an array to the buffer. If the array is already in serialized form (an instance of + * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element + * copy. + */ + private def writeArray( + bufferHolder: BufferHolder, + arrayWriter: UnsafeArrayWriter, + elementWriter: (SpecializedGetters, Int) => Unit, + array: ArrayData, + elementSize: Int): Unit = array match { + case unsafe: UnsafeArrayData => + writeUnsafeData( + bufferHolder, + unsafe.getBaseObject, + unsafe.getBaseOffset, + unsafe.getSizeInBytes) + case _ => + val numElements = array.numElements() + arrayWriter.initialize(bufferHolder, numElements, elementSize) + var i = 0 + while (i < numElements) { + elementWriter.apply(array, i) + i += 1 + } + } + + /** + * Write an opaque block of data to the buffer. This is used to copy + * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. + */ + private def writeUnsafeData( + bufferHolder: BufferHolder, + baseObject: AnyRef, + baseOffset: Long, + sizeInBytes: Int) : Unit = { + bufferHolder.grow(sizeInBytes) + Platform.copyMemory( + baseObject, + baseOffset, + bufferHolder.buffer, + bufferHolder.cursor, + sizeInBytes) + bufferHolder.cursor += sizeInBytes + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 4523079060896..dd523d312e3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. """) -case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time @@ -79,4 +79,6 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def prettyName: String = "monotonically_increasing_id" override def sql: String = s"$prettyName()" + + override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 64b94f0a2c103..3cd73682188bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -object UnsafeProjection { - +trait UnsafeProjectionCreator { /** * Returns an UnsafeProjection for given StructType. * @@ -127,13 +126,13 @@ object UnsafeProjection { } /** - * Returns an UnsafeProjection for given sequence of Expressions (bounded). + * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) - GenerateUnsafeProjection.generate(unsafeExprs) + createProjection(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -146,6 +145,18 @@ object UnsafeProjection { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + protected def createProjection(exprs: Seq[Expression]): UnsafeProjection +} + +object UnsafeProjection extends UnsafeProjectionCreator { + + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(exprs) + } + /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. * TODO: refactor the plumbing and clean this up. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 22717f5954a45..6682ba55b18b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -247,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); + $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); } else { $writeElement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6c9937dacc70b..f36633867316e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -85,6 +85,8 @@ case class Rand(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } + + override def freshCopy(): Rand = Rand(child) } object Rand { @@ -120,6 +122,8 @@ case class Randn(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } + + override def freshCopy(): Randn = Randn(child) } object Randn { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 84190f0bd5f7d..b4138ce366b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { null, null) } intercept[RuntimeException] { - checkEvalutionWithUnsafeProjection( + checkEvaluationWithUnsafeProjection( CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 58d0c07622eb9..c6343b1cbf600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { - checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) + checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } checkEvaluationWithOptimization(expr, catalystValue, inputRow) } @@ -187,11 +187,20 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { plan(inputRow).get(0, expression.dataType) } - protected def checkEvalutionWithUnsafeProjection( + protected def checkEvaluationWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) + } + + protected def checkEvaluationWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow, + factory: UnsafeProjectionCreator): Unit = { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -203,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } else { val lit = InternalRow(expected, expected) val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + factory.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") @@ -213,7 +222,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow): InternalRow = { + inputRow: InternalRow = EmptyRow, + factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index ffeec2a38c532..1f6964dfef598 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -45,16 +45,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) - checkEvalutionWithUnsafeProjection( - structEncoder.serializer.head, structExpected, structInputRow) + checkEvaluationWithUnsafeProjection( + structEncoder.serializer.head, + structExpected, + structInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) - checkEvalutionWithUnsafeProjection( - arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + checkEvaluationWithUnsafeProjection( + arrayEncoder.serializer.head, + arrayExpected, + arrayInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -67,8 +73,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new ArrayBasedMapData( new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) - checkEvalutionWithUnsafeProjection( - mapEncoder.serializer.head, mapExpected, mapInputRow) + checkEvaluationWithUnsafeProjection( + mapEncoder.serializer.head, + mapExpected, + mapInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } test("SPARK-23585: UnwrapOption should support interpreted execution") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 10e3ffd0dff97..e083ae0089244 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e1.getMessage.contains("Failed to execute user defined function")) val e2 = intercept[SparkException] { - checkEvalutionWithUnsafeProjection(udf, null) + checkEvaluationWithUnsafeProjection(udf, null) } assert(e2.getMessage.contains("Failed to execute user defined function")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index cf3cbe270753e..c07da122cd7b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - test("basic conversion with only primitive types") { - val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = UnsafeProjection.create(fieldTypes) + private def testWithFactory( + name: String)( + f: UnsafeProjectionCreator => Unit): Unit = { + test(name) { + f(UnsafeProjection) + f(InterpretedUnsafeProjection) + } + } + testWithFactory("basic conversion with only primitive types") { factory => + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) @@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - test("basic conversion with primitive, string and binary types") { + testWithFactory("basic conversion with primitive, string and binary types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - test("basic conversion with primitive, string, date and timestamp types") { + testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - test("null handling") { + testWithFactory("null handling") { factory => val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificInternalRow(fieldTypes) @@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - test("NaN canonicalization") { + testWithFactory("NaN canonicalization") { factory => val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - test("basic conversion with struct type") { + testWithFactory("basic conversion with struct type") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) @@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - test("basic conversion with array type") { + testWithFactory("basic conversion with array type") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) @@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - test("basic conversion with map type") { + testWithFactory("basic conversion with map type") { factory => val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val map1 = createMap(1, 2)(3, 4) @@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - test("basic conversion with struct and array") { + testWithFactory("basic conversion with struct and array") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) @@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with struct and map") { + testWithFactory("basic conversion with struct and map") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) @@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with array and map") { + testWithFactory("basic conversion with array and map") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) From 9945b0227efcd952c8e835453b2831a8c6d5d607 Mon Sep 17 00:00:00 2001 From: Ricardo Martinelli de Oliveira Date: Fri, 16 Mar 2018 10:37:11 -0700 Subject: [PATCH 175/593] [SPARK-23680] Fix entrypoint.sh to properly support Arbitrary UIDs ## What changes were proposed in this pull request? As described in SPARK-23680, entrypoint.sh returns an error code because of a command pipeline execution where it is expected in case of Openshift environments, where arbitrary UIDs are used to run containers ## How was this patch tested? This patch was manually tested by using docker-image-toll.sh script to generate a Spark driver image and running an example against an OpenShift cluster. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ricardo Martinelli de Oliveira Closes #20822 from rimolive/rmartine-spark-23680. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3d67b0a702dd4..d0cf284f035ea 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -22,7 +22,10 @@ set -ex # Check whether there is a passwd entry for the container UID myuid=$(id -u) mygid=$(id -g) +# turn off -e for getent because it will return error code in anonymous uid case +set +e uidentry=$(getent passwd $myuid) +set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then From bd201bf61e8e1713deb91b962f670c76c9e3492b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Mar 2018 11:11:07 -0700 Subject: [PATCH 176/593] [SPARK-23623][SS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? CacheKafkaConsumer in the project `kafka-0-10-sql` is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one task using trying to read the same Kafka TopicPartition at the same time. Hence, the cache was keyed by the TopicPartition a consumer is supposed to read. And any cases where this assumption may not be true, we have SparkPlan flag to disable the use of a cache. So it was up to the planner to correctly identify when it was not safe to use the cache and set the flag accordingly. Fundamentally, this is the wrong way to approach the problem. It is HARD for a high-level planner to reason about the low-level execution model, whether there will be multiple tasks in the same query trying to read the same partition. Case in point, 2.3.0 introduced stream-stream joins, and you can build a streaming self-join query on Kafka. It's pretty non-trivial to figure out how this leads to two tasks reading the same partition twice, possibly concurrently. And due to the non-triviality, it is hard to figure this out in the planner and set the flag to avoid the cache / consumer pool. And this can inadvertently lead to ConcurrentModificationException ,or worse, silent reading of incorrect data. Here is a better way to design this. The planner shouldnt have to understand these low-level optimizations. Rather the consumer pool should be smart enough avoid concurrent use of a cached consumer. Currently, it tries to do so but incorrectly (the flag inuse is not checked when returning a cached consumer, see [this](https://github.com/apache/spark/blob/master/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala#L403)). If there is another request for the same partition as a currently in-use consumer, the pool should automatically return a fresh consumer that should be closed when the task is done. Then the planner does not have to have a flag to avoid reuses. This PR is a step towards that goal. It does the following. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called KafkaConsumer is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply called `val consumer = KafkaConsumer.acquire` and then `consumer.release()`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is a concurrent attempt of the same task, then a new consumer is generated, and the existing cached consumer is marked for close upon release. - In addition, I renamed the classes because CachedKafkaConsumer is a misnomer given that what it returns may or may not be cached. This PR does not remove the planner flag to avoid reuse to make this patch safe enough for merging in branch-2.3. This can be done later in master-only. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same partition from the consumer pool. Author: Tathagata Das Closes #20767 from tdas/SPARK-23623. --- .../sql/kafka010/KafkaContinuousReader.scala | 5 +- ...Consumer.scala => KafkaDataConsumer.scala} | 242 ++++++++++++------ .../sql/kafka010/KafkaMicroBatchReader.scala | 22 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 23 +- .../kafka010/CachedKafkaConsumerSuite.scala | 34 --- .../sql/kafka010/KafkaDataConsumerSuite.scala | 124 +++++++++ 6 files changed, 295 insertions(+), 155 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{CachedKafkaConsumer.scala => KafkaDataConsumer.scala} (66%) delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 6e56b0a72d671..e7e27876088f3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -196,8 +196,7 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val consumer = - CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset @@ -245,6 +244,6 @@ class KafkaContinuousDataReader( } override def close(): Unit = { - consumer.close() + consumer.release() } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala similarity index 66% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index e97881cb0a163..48508d057a540 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -27,30 +27,73 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread +private[kafka010] sealed trait KafkaDataConsumer { + /** + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. + */ + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss) + } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange() + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + protected def internalConsumer: InternalKafkaConsumer +} + /** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected. + * This is not for direct use outside this file. */ -private[kafka010] case class CachedKafkaConsumer private( +private[kafka010] case class InternalKafkaConsumer( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { - import CachedKafkaConsumer._ + import InternalKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private var consumer = createConsumer + @volatile private var consumer = createConsumer /** indicates whether this consumer is in use or not */ - private var inuse = true + @volatile var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + @volatile var markedForClose = false /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = UNKNOWN_OFFSET + @volatile private var fetchedData = + ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -61,8 +104,6 @@ private[kafka010] case class CachedKafkaConsumer private( c } - case class AvailableOffsetRange(earliest: Long, latest: Long) - private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { case ut: UninterruptibleThread => ut.runUninterruptibly(body) @@ -313,21 +354,51 @@ private[kafka010] case class CachedKafkaConsumer private( } } -private[kafka010] object CachedKafkaConsumer extends Logging { - private val UNKNOWN_OFFSET = -2L +private[kafka010] object KafkaDataConsumer extends Logging { + + case class AvailableOffsetRange(earliest: Long, latest: Long) + + private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + assert(internalConsumer.inUse) // make sure this has been set to true + override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) } + } + + private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + override def release(): Unit = { internalConsumer.close() } + } - private case class CacheKey(groupId: String, topicPartition: TopicPartition) + private case class CacheKey(groupId: String, topicPartition: TopicPartition) { + def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) = + this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition) + } + // This cache has the following important properties. + // - We make a best-effort attempt to maintain the max size of the cache as configured capacity. + // The capacity is not guaranteed to be maintained, especially when there are more active + // tasks simultaneously using consumers than the capacity. private lazy val cache = { val conf = SparkEnv.get.conf val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) - new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) { override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { - if (entry.getValue.inuse == false && this.size > capacity) { - logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + - s"removing consumer for ${entry.getKey}") + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > capacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") try { entry.getValue.close() } catch { @@ -342,80 +413,87 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } - def releaseKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - synchronized { - val consumer = cache.get(key) - if (consumer != null) { - consumer.inuse = false - } else { - logWarning(s"Attempting to release consumer that does not exist") - } - } - } - /** - * Removes (and closes) the Kafka Consumer for the given topic, partition and group id. + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by any one + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. */ - def removeKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) + def acquire( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + useCache: Boolean): KafkaDataConsumer = synchronized { + val key = new CacheKey(topicPartition, kafkaParams) + val existingInternalConsumer = cache.get(key) - synchronized { - val removedConsumer = cache.remove(key) - if (removedConsumer != null) { - removedConsumer.close() + lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams) + + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumer if any and + // start with a new one. + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + } } + cache.remove(key) // Invalidate the cache in any case + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (!useCache) { + // If planner asks to not reuse consumers, then do not use it, return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + cache.put(key, newInternalConsumer) + newInternalConsumer.inUse = true + CachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else { + // If consumer is already cached and is currently not in use, then return that consumer + existingInternalConsumer.inUse = true + CachedKafkaDataConsumer(existingInternalConsumer) } } - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def getOrCreate( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - // If this is reattempt at running the task, then invalidate cache and start with - // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { - removeKafkaConsumer(topic, partition, kafkaParams) - val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) - consumer.inuse = true - cache.put(key, consumer) - consumer - } else { - if (!cache.containsKey(key)) { - cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + private def release(intConsumer: InternalKafkaConsumer): Unit = { + synchronized { + + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams) + val cachedIntConsumer = cache.get(key) + if (intConsumer.eq(cachedIntConsumer)) { + // The released consumer is the same object as the cached one. + if (intConsumer.markedForClose) { + intConsumer.close() + cache.remove(key) + } else { + intConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + intConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache") } - val consumer = cache.get(key) - consumer.inuse = true - consumer } } +} - /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ - def createUncached( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { - new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) - } +private[kafka010] object InternalKafkaConsumer extends Logging { + + private val UNKNOWN_OFFSET = -2L private def reportDataLoss0( failOnDataLoss: Boolean, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a5f3a249b11c..2ed49ba3f5495 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -321,17 +321,8 @@ private[kafka010] case class KafkaMicroBatchDataReader( failOnDataLoss: Boolean, reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { - private val consumer = { - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We - // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } - } + private val consumer = KafkaDataConsumer.acquire( + offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -360,14 +351,7 @@ private[kafka010] case class KafkaMicroBatchDataReader( } override def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } + consumer.release() } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 66b3409c0cd04..498e344ea39f4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[KafkaDataConsumer]] can be used read data efficiently. * * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors @@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD( val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we - // uses `assign`, we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) - } + val consumer = KafkaDataConsumer.acquire( + sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD( } override protected def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) - } + consumer.release() } } // Release consumer, either by removing it or indicating we're no longer using it @@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD( } } - private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = { + private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = { if (range.fromOffset < 0 || range.untilOffset < 0) { // Late bind the offset range val availableOffsetRange = consumer.getAvailableOffsetRange() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala deleted file mode 100644 index 7aa7dd096c07b..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.scalatest.PrivateMethodTester - -import org.apache.spark.sql.test.SharedSQLContext - -class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { - - test("SPARK-19886: Report error cause correctly in reportDataLoss") { - val cause = new Exception("D'oh!") - val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) - val e = intercept[IllegalStateException] { - CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) - } - assert(e.getCause === cause) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..0d0fb9c3ab5af --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.PrivateMethodTester + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.ThreadUtils + +class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + protected var testUtils: KafkaTestUtils = _ + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils(Map[String, Object]()) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } + + test("SPARK-23623: concurrent use of KafkaDataConsumer") { + val topic = "topic" + Random.nextInt() + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic, 1) + testUtils.sendMessages(topic, data.toArray) + val topicPartition = new TopicPartition(topic, 0) + + import ConsumerConfig._ + val kafkaParams = Map[String, Object]( + GROUP_ID_CONFIG -> "groupId", + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ) + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + TaskContext.setTaskContext(taskContext) + val consumer = KafkaDataConsumer.acquire( + topicPartition, kafkaParams.asJava, useCache) + try { + val range = consumer.getAvailableOffsetRange() + val rcvd = range.earliest until range.latest map { offset => + val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadpool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadpool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadpool.shutdown() + } + } +} From 8a72734f33f6a0abbd3207b0d661633c8b25d9ad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 16 Mar 2018 11:42:57 -0700 Subject: [PATCH 177/593] [SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list ## What changes were proposed in this pull request? Added a class method to construct CountVectorizerModel from a list of vocabulary strings, equivalent to the Scala version. Introduced a common param base class `_CountVectorizerParams` to allow the Python model to also own the parameters. This now matches the Scala class hierarchy. ## How was this patch tested? Added to CountVectorizer doctests to do a transform on a model constructed from vocab, and unit test to verify params and vocab are constructed correctly. Author: Bryan Cutler Closes #16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009. --- python/pyspark/ml/feature.py | 168 +++++++++++++++++++++++------------ python/pyspark/ml/tests.py | 32 ++++++- 2 files changed, 142 insertions(+), 58 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f2e357f0bede5..a1ceb7f02da8b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,12 +19,12 @@ if sys.version > '3': basestring = str -from pyspark import since, keyword_only +from pyspark import since, keyword_only, SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', @@ -403,8 +403,69 @@ def getSplits(self): return self.getOrDefault(self.splits) +class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`. + """ + + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0", typeConverter=TypeConverters.toFloat) + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) + + def __init__(self, *args): + super(_CountVectorizerParams, self).__init__(*args) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + + @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. @@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, >>> loadedModel = CountVectorizerModel.load(modelPath) >>> loadedModel.vocabulary == model.vocabulary True + >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], + ... inputCol="raw", outputCol="vectors") + >>> fromVocabModel.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... .. versionadded:: 1.6.0 """ - minTF = Param( - Params._dummy(), "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then this " + - "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", - typeConverter=TypeConverters.toFloat) - minDF = Param( - Params._dummy(), "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0", typeConverter=TypeConverters.toFloat) - vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", - typeConverter=TypeConverters.toInt) - binary = Param( - Params._dummy(), "binary", "Binary toggle to control the output vector values." + - " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + - " for discrete probabilistic models that model binary events rather than integer counts." + - " Default False", typeConverter=TypeConverters.toBoolean) - @keyword_only def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None): @@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -497,13 +544,6 @@ def setMinTF(self, value): """ return self._set(minTF=value) - @since("1.6.0") - def getMinTF(self): - """ - Gets the value of minTF or its default value. - """ - return self.getOrDefault(self.minTF) - @since("1.6.0") def setMinDF(self, value): """ @@ -511,13 +551,6 @@ def setMinDF(self, value): """ return self._set(minDF=value) - @since("1.6.0") - def getMinDF(self): - """ - Gets the value of minDF or its default value. - """ - return self.getOrDefault(self.minDF) - @since("1.6.0") def setVocabSize(self, value): """ @@ -525,13 +558,6 @@ def setVocabSize(self, value): """ return self._set(vocabSize=value) - @since("1.6.0") - def getVocabSize(self): - """ - Gets the value of vocabSize or its default value. - """ - return self.getOrDefault(self.vocabSize) - @since("2.0.0") def setBinary(self, value): """ @@ -539,24 +565,40 @@ def setBinary(self, value): """ return self._set(binary=value) - @since("2.0.0") - def getBinary(self): - """ - Gets the value of binary or its default value. - """ - return self.getOrDefault(self.binary) - def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): +@inherit_doc +class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ + @classmethod + @since("2.4.0") + def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): + """ + Construct the model directly from a vocabulary list of strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) + model = CountVectorizerModel._create_from_java_class( + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if minTF is not None: + model.setMinTF(minTF) + if binary is not None: + model.setBinary(binary) + model._set(vocabSize=len(vocabulary)) + return model + @property @since("1.6.0") def vocabulary(self): @@ -565,6 +607,20 @@ def vocabulary(self): """ return self._call_java("vocabulary") + @since("2.4.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + return self._set(minTF=value) + + @since("2.4.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6dee6938d8916..fd45fd00b270b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -679,6 +679,34 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + def test_rformula_force_index_label(self): df = self.spark.createDataFrame([ (1.0, 1.0, "a"), @@ -2019,8 +2047,8 @@ def test_java_params(self): pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and issubclass(cls, JavaParams)\ - and not inspect.isabstract(cls): + if not name.endswith('Model') and not name.endswith('Params')\ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) From 8a1efe3076f29259151f1fba2ff894487efb6c4e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 16 Mar 2018 15:40:21 -0700 Subject: [PATCH 178/593] [SPARK-23683][SQL] FileCommitProtocol.instantiate() hardening ## What changes were proposed in this pull request? With SPARK-20236, `FileCommitProtocol.instantiate()` looks for a three argument constructor, passing in the `dynamicPartitionOverwrite` parameter. If there is no such constructor, it falls back to the classic two-arg one. When `InsertIntoHadoopFsRelationCommand` passes down that `dynamicPartitionOverwrite` flag `to FileCommitProtocol.instantiate(`), it assumes that the instantiated protocol supports the specific requirements of dynamic partition overwrite. It does not notice when this does not hold, and so the output generated may be incorrect. This patch changes `FileCommitProtocol.instantiate()` so when `dynamicPartitionOverwrite == true`, it requires the protocol implementation to have a 3-arg constructor. Classic two arg constructors are supported when it is false. Also it adds some debug level logging for anyone trying to understand what's going on. ## How was this patch tested? Unit tests verify that * classes with only 2-arg constructor cannot be used with dynamic overwrite * classes with only 2-arg constructor can be used without dynamic overwrite * classes with 3 arg constructors can be used with both. * the fallback to any two arg ctor takes place after the attempt to load the 3-arg ctor, * passing in invalid class types fail as expected (regression tests on expected behavior) Author: Steve Loughran Closes #20824 from steveloughran/stevel/SPARK-23683-protocol-instantiate. --- .../internal/io/FileCommitProtocol.scala | 11 +- ...FileCommitProtocolInstantiationSuite.scala | 148 ++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 6d0059b6a0272..e6e9c9e328853 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import org.apache.hadoop.fs._ import org.apache.hadoop.mapreduce._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -132,7 +133,7 @@ abstract class FileCommitProtocol { } -object FileCommitProtocol { +object FileCommitProtocol extends Logging { class TaskCommitMessage(val obj: Any) extends Serializable object EmptyTaskCommitMessage extends TaskCommitMessage(null) @@ -145,15 +146,23 @@ object FileCommitProtocol { jobId: String, outputPath: String, dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { + + logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" + + s" dynamic=$dynamicPartitionOverwrite") val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] // First try the constructor with arguments (jobId: String, outputPath: String, // dynamicPartitionOverwrite: Boolean). // If that doesn't exist, try the one with (jobId: string, outputPath: String). try { val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + logDebug("Using (String, String, Boolean) constructor") ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) } catch { case _: NoSuchMethodException => + logDebug("Falling back to (String, String) constructor") + require(!dynamicPartitionOverwrite, + "Dynamic Partition Overwrite is enabled but" + + s" the committer ${className} does not have the appropriate constructor") val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) ctor.newInstance(jobId, outputPath) } diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala new file mode 100644 index 0000000000000..2bd32fc927e21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for instantiation of FileCommitProtocol implementations. + */ +class FileCommitProtocolInstantiationSuite extends SparkFunSuite { + + test("Dynamic partitions require appropriate constructor") { + + // you cannot instantiate a two-arg client with dynamic partitions + // enabled. + val ex = intercept[IllegalArgumentException] { + instantiateClassic(true) + } + // check the contents of the message and rethrow if unexpected. + // this preserves the stack trace of the unexpected + // exception. + if (!ex.toString.contains("Dynamic Partition Overwrite")) { + fail(s"Wrong text in caught exception $ex", ex) + } + } + + test("Standard partitions work with classic constructor") { + instantiateClassic(false) + } + + test("Three arg constructors have priority") { + assert(3 == instantiateNew(false).argCount, + "Wrong constructor argument count") + } + + test("Three arg constructors have priority when dynamic") { + assert(3 == instantiateNew(true).argCount, + "Wrong constructor argument count") + } + + test("The protocol must be of the correct class") { + intercept[ClassCastException] { + FileCommitProtocol.instantiate( + classOf[Other].getCanonicalName, + "job", + "path", + false) + } + } + + test("If there is no matching constructor, class hierarchy is irrelevant") { + intercept[NoSuchMethodException] { + FileCommitProtocol.instantiate( + classOf[NoMatchingArgs].getCanonicalName, + "job", + "path", + false) + } + } + + /** + * Create a classic two-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[ClassicConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[ClassicConstructorCommitProtocol] + } + + /** + * Create a three-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateNew( + dynamic: Boolean): FullConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[FullConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[FullConstructorCommitProtocol] + } + +} + +/** + * This protocol implementation does not have the new three-arg + * constructor. + */ +private class ClassicConstructorCommitProtocol(arg1: String, arg2: String) + extends HadoopMapReduceCommitProtocol(arg1, arg2) { +} + +/** + * This protocol implementation does have the new three-arg constructor + * alongside the original, and a 4 arg one for completeness. + * The final value of the real constructor is the number of arguments + * used in the 2- and 3- constructor, for test assertions. + */ +private class FullConstructorCommitProtocol( + arg1: String, + arg2: String, + b: Boolean, + val argCount: Int) + extends HadoopMapReduceCommitProtocol(arg1, arg2, b) { + + def this(arg1: String, arg2: String) = { + this(arg1, arg2, false, 2) + } + + def this(arg1: String, arg2: String, b: Boolean) = { + this(arg1, arg2, false, 3) + } +} + +/** + * This has the 2-arity constructor, but isn't the right class. + */ +private class Other(arg1: String, arg2: String) { + +} + +/** + * This has no matching arguments as well as being the wrong class. + */ +private class NoMatchingArgs() { + +} + From 61487b308b0169e3108c2ad31674a0c80b8ac5f3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Mar 2018 20:24:14 +0900 Subject: [PATCH 179/593] [SPARK-23706][PYTHON] spark.conf.get(value, default=None) should produce None in PySpark ## What changes were proposed in this pull request? Scala: ``` scala> spark.conf.get("hey", null) res1: String = null ``` ``` scala> spark.conf.get("spark.sql.sources.partitionOverwriteMode", null) res2: String = null ``` Python: **Before** ``` >>> spark.conf.get("hey", None) ... py4j.protocol.Py4JJavaError: An error occurred while calling o30.get. : java.util.NoSuchElementException: hey ... ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) u'STATIC' ``` **After** ``` >>> spark.conf.get("hey", None) is None True ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) is None True ``` *Note that this PR preserves the case below: ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode") u'STATIC' ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20841 from HyukjinKwon/spark-conf-get. --- python/pyspark/sql/conf.py | 9 +++++---- python/pyspark/sql/context.py | 8 ++++---- python/pyspark/sql/tests.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index d929834aeeaa5..b82224b6194ed 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -17,7 +17,7 @@ import sys -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix @@ -39,15 +39,16 @@ def set(self, key, value): @ignore_unicode_prefix @since(2.0) - def get(self, key, default=None): + def get(self, key, default=_NoValue): """Returns the value of Spark runtime configuration property for the given key, assuming it is set. """ self._checkType(key, "key") - if default is None: + if default is _NoValue: return self._jconf.get(key) else: - self._checkType(default, "default") + if default is not None: + self._checkType(default, "default") return self._jconf.get(key, default) @ignore_unicode_prefix diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 6cb90399dd616..e9ec7ba866761 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -22,7 +22,7 @@ if sys.version >= '3': basestring = unicode = str -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame @@ -124,11 +124,11 @@ def setConf(self, key, value): @ignore_unicode_prefix @since(1.3) - def getConf(self, key, defaultValue=None): + def getConf(self, key, defaultValue=_NoValue): """Returns the value of Spark SQL configuration property for the given key. - If the key is not set and defaultValue is not None, return - defaultValue. If the key is not set and defaultValue is None, return + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return the system default value. >>> sqlContext.getConf("spark.sql.shuffle.partitions") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 480815d27333f..a0d547ad620e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2504,6 +2504,17 @@ def test_conf(self): spark.conf.unset("bogo") self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + def test_current_database(self): spark = self.spark spark.catalog._reset() From 745c8c0901ac522ba92c1356ca74bd0dd7701496 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Mon, 19 Mar 2018 13:31:21 +0800 Subject: [PATCH 180/593] [SPARK-23708][CORE] Correct comment for function addShutDownHook in ShutdownHookManager ## What changes were proposed in this pull request? Minor modification.Comment below is not right. ``` /** * Adds a shutdown hook with the given priority. Hooks with lower priority values run * first. * * param hook The code to run during shutdown. * return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } ``` ## How was this patch tested? UT Author: zhoukang Closes #20845 from caneGuy/zhoukang/fix-shutdowncomment. --- .../main/scala/org/apache/spark/util/ShutdownHookManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4001fac3c3d5a..b702838fa257f 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -143,7 +143,7 @@ private[spark] object ShutdownHookManager extends Logging { } /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * Adds a shutdown hook with the given priority. Hooks with higher priority values run * first. * * @param hook The code to run during shutdown. From 4de638c1976dea74761bbe5c30da808178ee885d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 19 Mar 2018 09:41:43 +0100 Subject: [PATCH 181/593] [SPARK-23599][SQL] Add a UUID generator from Pseudo-Random Numbers ## What changes were proposed in this pull request? This patch adds a UUID generator from Pseudo-Random Numbers. We can use it later to have deterministic `UUID()` expression. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20817 from viirya/SPARK-23599. --- .../catalyst/util/RandomUUIDGenerator.scala | 43 ++++++++++++++ .../util/RandomUUIDGeneratorSuite.scala | 57 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala new file mode 100644 index 0000000000000..4fe07a071c1ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.UUID + +import org.apache.commons.math3.random.MersenneTwister + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class is used to generate a UUID from Pseudo-Random Numbers. + * + * For the algorithm, see RFC 4122: A Universally Unique IDentifier (UUID) URN Namespace, + * section 4.4 "Algorithms for Creating a UUID from Truly Random or Pseudo-Random Numbers". + */ +case class RandomUUIDGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextUUID(): UUID = { + val mostSigBits = (random.nextLong() & 0xFFFFFFFFFFFF0FFFL) | 0x0000000000004000L + val leastSigBits = (random.nextLong() | 0x8000000000000000L) & 0xBFFFFFFFFFFFFFFFL + + new UUID(mostSigBits, leastSigBits) + } + + def getNextUUIDUTF8String(): UTF8String = UTF8String.fromString(getNextUUID().toString()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala new file mode 100644 index 0000000000000..b75739e5a3a65 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + +class RandomUUIDGeneratorSuite extends SparkFunSuite { + test("RandomUUIDGenerator should generate version 4, variant 2 UUIDs") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + for (_ <- 0 to 100) { + val uuid = generator.getNextUUID() + assert(uuid.version() == 4) + assert(uuid.variant() == 2) + } + } + + test("UUID from RandomUUIDGenerator should be deterministic") { + val r1 = new Random(100) + val generator1 = RandomUUIDGenerator(r1.nextLong()) + val r2 = new Random(100) + val generator2 = RandomUUIDGenerator(r2.nextLong()) + val r3 = new Random(101) + val generator3 = RandomUUIDGenerator(r3.nextLong()) + + for (_ <- 0 to 100) { + val uuid1 = generator1.getNextUUID() + val uuid2 = generator2.getNextUUID() + val uuid3 = generator3.getNextUUID() + assert(uuid1 == uuid2) + assert(uuid1 != uuid3) + } + } + + test("Get UTF8String UUID") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + val utf8StringUUID = generator.getNextUUIDUTF8String() + val uuid = java.util.UUID.fromString(utf8StringUUID.toString) + assert(uuid.version() == 4 && uuid.variant() == 2 && utf8StringUUID.toString == uuid.toString) + } +} From f15906da153f139b698e192ec6f82f078f896f1e Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Mon, 19 Mar 2018 11:29:56 -0700 Subject: [PATCH 182/593] [SPARK-22839][K8S] Remove the use of init-container for downloading remote dependencies ## What changes were proposed in this pull request? Removal of the init-container for downloading remote dependencies. Built off of the work done by vanzin in an attempt to refactor driver/executor configuration elaborated in [this](https://issues.apache.org/jira/browse/SPARK-22839) ticket. ## How was this patch tested? This patch was tested with unit and integration tests. Author: Ilan Filonenko Closes #20669 from ifilonenko/remove-init-container. --- bin/docker-image-tool.sh | 9 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 - docs/running-on-kubernetes.md | 71 +------- .../spark/examples/SparkRemoteFileTest.scala | 48 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 73 +------- .../apache/spark/deploy/k8s/Constants.scala | 21 +-- .../deploy/k8s/InitContainerBootstrap.scala | 120 ------------- .../spark/deploy/k8s/KubernetesUtils.scala | 63 +------ .../k8s/PodWithDetachedInitContainer.scala | 31 ---- .../deploy/k8s/SparkPodInitContainer.scala | 116 ------------- .../k8s/submit/DriverConfigOrchestrator.scala | 45 +---- .../submit/KubernetesClientApplication.scala | 84 +++++---- .../steps/BasicDriverConfigurationStep.scala | 32 ++-- .../steps/DependencyResolutionStep.scala | 18 +- .../DriverInitContainerBootstrapStep.scala | 95 ----------- .../DriverKubernetesCredentialsStep.scala | 2 +- .../BasicInitContainerConfigurationStep.scala | 67 -------- .../InitContainerConfigOrchestrator.scala | 79 --------- .../InitContainerConfigurationStep.scala | 25 --- .../InitContainerMountSecretsStep.scala | 36 ---- .../initcontainer/InitContainerSpec.scala | 37 ---- .../cluster/k8s/ExecutorPodFactory.scala | 43 +---- .../k8s/KubernetesClusterManager.scala | 65 +------ .../k8s/SparkPodInitContainerSuite.scala | 86 ---------- .../spark/deploy/k8s/submit/ClientSuite.scala | 82 ++++----- .../DriverConfigOrchestratorSuite.scala | 41 +---- .../BasicDriverConfigurationStepSuite.scala | 8 +- .../steps/DependencyResolutionStepSuite.scala | 32 ++-- ...riverInitContainerBootstrapStepSuite.scala | 160 ------------------ ...cInitContainerConfigurationStepSuite.scala | 95 ----------- ...InitContainerConfigOrchestratorSuite.scala | 80 --------- .../InitContainerMountSecretsStepSuite.scala | 52 ------ .../cluster/k8s/ExecutorPodFactorySuite.scala | 67 +------- .../src/main/dockerfiles/spark/Dockerfile | 1 - .../src/main/dockerfiles/spark/entrypoint.sh | 20 +-- 35 files changed, 241 insertions(+), 1665 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 0d0f564bb8b9b..f090240065bf1 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -64,9 +64,11 @@ function build { error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi + local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + docker build "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$IMG_PATH/spark/Dockerfile" . + -f "$DOCKERFILE" . } function push { @@ -84,6 +86,7 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: + -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -113,10 +116,12 @@ fi REPO= TAG= -while getopts mr:t: option +DOCKERFILE= +while getopts f:mr:t: option do case "${option}" in + f) DOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; m) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e381965c52ba..329bde08718fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -320,8 +320,6 @@ object SparkSubmit extends CommandLineUtils with Logging { printErrorAndExit("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => printErrorAndExit("R applications are currently not supported for Kubernetes.") - case (KUBERNETES, CLIENT) => - printErrorAndExit("Client mode is currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 3c7586e8544ba..975b28de47e20 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -126,29 +126,6 @@ Those dependencies can be added to the classpath by referencing them with `local dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission client's local file system is currently not yet supported. - -### Using Remote Dependencies -When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods -need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. - -The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and -`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., -the main application jar. The following shows an example of using remote dependencies with the `spark-submit` command: - -```bash -$ bin/spark-submit \ - --master k8s://https://: \ - --deploy-mode cluster \ - --name spark-pi \ - --class org.apache.spark.examples.SparkPi \ - --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar - --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 - --conf spark.executor.instances=5 \ - --conf spark.kubernetes.container.image= \ - https://path/to/examples.jar -``` - ## Secret Management Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a Spark application to access secured services. To mount a user-specified secret into the driver container, users can use @@ -163,10 +140,6 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` -Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the -init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the -init-container of the executor. - ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -604,51 +577,12 @@ specific to Spark on Kubernetes. the Driver process. The user can specify multiple of these to set multiple environment variables. - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.timeout - 300s - - Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into - the driver and executor pods. - - - - spark.kubernetes.mountDependencies.maxSimultaneousDownloads - 5 - - Maximum number of remote dependencies to download simultaneously in a driver or executor pod. - - - - spark.kubernetes.initContainer.image - (value of spark.kubernetes.container.image) - - Custom container image for the init container of both driver and executors. - - spark.kubernetes.driver.secrets.[SecretName] (none) Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, - spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the driver pod. + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. @@ -656,8 +590,7 @@ specific to Spark on Kubernetes. (none) Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, - spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the executor pod. + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. \ No newline at end of file diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala new file mode 100644 index 0000000000000..64076f2deb706 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import org.apache.spark.SparkFiles +import org.apache.spark.sql.SparkSession + +/** Usage: SparkRemoteFileTest [file] */ +object SparkRemoteFileTest { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: SparkRemoteFileTest ") + System.exit(1) + } + val spark = SparkSession + .builder() + .appName("SparkRemoteFileTest") + .getOrCreate() + val sc = spark.sparkContext + val rdd = sc.parallelize(Seq(1)).map(_ => { + val localLocation = SparkFiles.get(args(0)) + println(s"${args(0)} is stored at: $localLocation") + new File(localLocation).isFile + }) + val truthCheck = rdd.collect().head + println(s"Mounting of ${args(0)} was $truthCheck") + spark.stop() + } +} +// scalastyle:on println diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 471196ac0e3f6..da34a7e06238a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -79,6 +79,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_DRIVER_SUBMIT_CHECK = + ConfigBuilder("spark.kubernetes.submitInDriver") + .internal() + .booleanConf + .createOptional + val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") .doc("Specify the hard cpu limit for each executor pod") @@ -135,73 +141,6 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") - val JARS_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pod.") - .stringConf - .createWithDefault("/var/spark-data/spark-jars") - - val FILES_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pods.") - .stringConf - .createWithDefault("/var/spark-data/spark-files") - - val INIT_CONTAINER_IMAGE = - ConfigBuilder("spark.kubernetes.initContainer.image") - .doc("Image for the driver and executor's init-container for downloading dependencies.") - .fallbackConf(CONTAINER_IMAGE) - - val INIT_CONTAINER_MOUNT_TIMEOUT = - ConfigBuilder("spark.kubernetes.mountDependencies.timeout") - .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + - "locations into the driver and executor pods.") - .timeConf(TimeUnit.SECONDS) - .createWithDefault(300) - - val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = - ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") - .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + - "executor pod.") - .intConf - .createWithDefault(5) - - val INIT_CONTAINER_REMOTE_JARS = - ConfigBuilder("spark.kubernetes.initContainer.remoteJars") - .doc("Comma-separated list of jar URIs to download in the init-container. This is " + - "calculated from spark.jars.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_REMOTE_FILES = - ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") - .doc("Comma-separated list of file URIs to download in the init-container. This is " + - "calculated from spark.files.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_NAME = - ConfigBuilder("spark.kubernetes.initContainer.configMapName") - .doc("Name of the config map to use in the init-container that retrieves submitted files " + - "for the executor.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = - ConfigBuilder("spark.kubernetes.initContainer.configMapKey") - .doc("Key for the entry in the init container config map for submitted files that " + - "corresponds to the properties for this init-container.") - .internal() - .stringConf - .createOptional - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 9411956996843..8da5f24044aad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -63,22 +63,13 @@ private[spark] object Constants { val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" - val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" - val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" - val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" - val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" - val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" - - // Bootstrapping dependencies with the init-container - val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" - val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" - val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" - val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" - val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" - val INIT_CONTAINER_PROPERTIES_FILE_PATH = - s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" - val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" + val ENV_SPARK_CONF_DIR = "SPARK_CONF_DIR" + // Spark app configs for containers + val SPARK_CONF_VOLUME = "spark-conf-volume" + val SPARK_CONF_DIR_INTERNAL = "/opt/spark/conf" + val SPARK_CONF_FILE_NAME = "spark.properties" + val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala deleted file mode 100644 index f6a57dfe00171..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Bootstraps an init-container for downloading remote dependencies. This is separated out from - * the init-container steps API because this component can be used to bootstrap init-containers - * for both the driver and executors. - */ -private[spark] class InitContainerBootstrap( - initContainerImage: String, - imagePullPolicy: String, - jarsDownloadPath: String, - filesDownloadPath: String, - configMapName: String, - configMapKey: String, - sparkRole: String, - sparkConf: SparkConf) { - - /** - * Bootstraps an init-container that downloads dependencies to be used by a main container. - */ - def bootstrapInitContainer( - original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { - val sharedVolumeMounts = Seq[VolumeMount]( - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withMountPath(jarsDownloadPath) - .build(), - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withMountPath(filesDownloadPath) - .build()) - - val customEnvVarKeyPrefix = sparkRole match { - case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY - case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." - case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") - } - val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { - case (key, value) => - new EnvVarBuilder() - .withName(key) - .withValue(value) - .build() - } - - val initContainer = new ContainerBuilder(original.initContainer) - .withName("spark-init") - .withImage(initContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(customEnvVars.asJava) - .addNewVolumeMount() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) - .endVolumeMount() - .addToVolumeMounts(sharedVolumeMounts: _*) - .addToArgs("init") - .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) - .build() - - val podWithBasicVolumes = new PodBuilder(original.pod) - .editSpec() - .addNewVolume() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withNewConfigMap() - .withName(configMapName) - .addNewItem() - .withKey(configMapKey) - .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .endSpec() - .build() - - val mainContainer = new ContainerBuilder(original.mainContainer) - .addToVolumeMounts(sharedVolumeMounts: _*) - .addNewEnv() - .withName(ENV_MOUNTED_FILES_DIR) - .withValue(filesDownloadPath) - .endEnv() - .build() - - PodWithDetachedInitContainer( - podWithBasicVolumes, - initContainer, - mainContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 37331d8bbf9b7..5bc070147d3a8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,10 +16,6 @@ */ package org.apache.spark.deploy.k8s -import java.io.File - -import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} - import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -43,72 +39,23 @@ private[spark] object KubernetesUtils { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } - /** - * Append the given init-container to a pod's list of init-containers. - * - * @param originalPodSpec original specification of the pod - * @param initContainer the init-container to add to the pod - * @return the pod with the init-container added to the list of InitContainers - */ - def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { - new PodBuilder(originalPodSpec) - .editOrNewSpec() - .addToInitContainers(initContainer) - .endSpec() - .build() - } - /** * For the given collection of file URIs, resolves them as follows: - * - File URIs with scheme file:// are resolved to the given download path. * - File URIs with scheme local:// resolve to just the path of the URI. * - Otherwise, the URIs are returned as-is. */ - def resolveFileUris( - fileUris: Iterable[String], - fileDownloadPath: String): Iterable[String] = { - fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, false) - } - } - - /** - * If any file uri has any scheme other than local:// it is mapped as if the file - * was downloaded to the file download path. Otherwise, it is mapped to the path - * part of the URI. - */ - def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = { + def resolveFileUrisAndPath(fileUris: Iterable[String]): Iterable[String] = { fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, true) - } - } - - /** - * Get from a given collection of file URIs the ones that represent remote files. - */ - def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { - uris.filter { uri => - val scheme = Utils.resolveURI(uri).getScheme - scheme != "file" && scheme != "local" + resolveFileUri(uri) } } - private def resolveFileUri( - uri: String, - fileDownloadPath: String, - assumesDownloaded: Boolean): String = { + private def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { - case "local" => - fileUri.getPath - case _ => - if (assumesDownloaded || fileScheme == "file") { - val fileName = new File(fileUri.getPath).getName - s"$fileDownloadPath/$fileName" - } else { - uri - } + case "local" => fileUri.getPath + case _ => uri } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala deleted file mode 100644 index 0b79f8b12e806..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, Pod} - -/** - * Represents a pod with a detached init-container (not yet added to the pod). - * - * @param pod the pod - * @param initContainer the init-container in the pod - * @param mainContainer the main container in the pod - */ -private[spark] case class PodWithDetachedInitContainer( - pod: Pod, - initContainer: Container, - mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala deleted file mode 100644 index c0f08786b76a1..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.concurrent.TimeUnit - -import scala.concurrent.{ExecutionContext, Future} - -import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -/** - * Process that fetches files from a resource staging server and/or arbitrary remote locations. - * - * The init-container can handle fetching files from any of those sources, but not all of the - * sources need to be specified. This allows for composing multiple instances of this container - * with different configurations for different download sources, or using the same container to - * download everything at once. - */ -private[spark] class SparkPodInitContainer( - sparkConf: SparkConf, - fileFetcher: FileFetcher) extends Logging { - - private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) - private implicit val downloadExecutor = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) - - private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) - private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) - - private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) - private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) - - private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) - - def run(): Unit = { - logInfo(s"Downloading remote jars: $remoteJars") - downloadFiles( - remoteJars, - jarsDownloadDir, - s"Remote jars download directory specified at $jarsDownloadDir does not exist " + - "or is not a directory.") - - logInfo(s"Downloading remote files: $remoteFiles") - downloadFiles( - remoteFiles, - filesDownloadDir, - s"Remote files download directory specified at $filesDownloadDir does not exist " + - "or is not a directory.") - - downloadExecutor.shutdown() - downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) - } - - private def downloadFiles( - filesCommaSeparated: Option[String], - downloadDir: File, - errMessage: String): Unit = { - filesCommaSeparated.foreach { files => - require(downloadDir.isDirectory, errMessage) - Utils.stringToSeq(files).foreach { file => - Future[Unit] { - fileFetcher.fetchFile(file, downloadDir) - } - } - } - } -} - -private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { - - def fetchFile(uri: String, targetDir: File): Unit = { - Utils.fetchFile( - url = uri, - targetDir = targetDir, - conf = sparkConf, - securityMgr = securityManager, - hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), - timestamp = System.currentTimeMillis(), - useCache = false) - } -} - -object SparkPodInitContainer extends Logging { - - def main(args: Array[String]): Unit = { - logInfo("Starting init-container to download Spark application dependencies.") - val sparkConf = new SparkConf(true) - if (args.nonEmpty) { - Utils.loadDefaultSparkProperties(sparkConf, args(0)) - } - - val securityManager = new SparkSecurityManager(sparkConf) - val fileFetcher = new FileFetcher(sparkConf, securityManager) - new SparkPodInitContainer(sparkConf, fileFetcher).run() - logInfo("Finished downloading application dependencies.") - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index ae70904621184..b4d3f04a1bc32 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -16,16 +16,11 @@ */ package org.apache.spark.deploy.k8s.submit -import java.util.UUID - -import com.google.common.primitives.Longs - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils @@ -34,13 +29,11 @@ import org.apache.spark.util.Utils * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to * configure the Spark driver pod. The returned steps will be applied one by one in the given * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to - * configure the driver init-container if one is needed, i.e., when there are remote dependencies - * to localize. + * to construct and create the driver pod. */ private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, - launchTime: Long, + kubernetesResourceNamePrefix: String, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, @@ -50,15 +43,8 @@ private[spark] class DriverConfigOrchestrator( // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the // application the user submitted. - private val kubernetesResourceNamePrefix = { - val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "") - s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") - } private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" - private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -126,9 +112,7 @@ private[spark] class DriverConfigOrchestrator( val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath)) + sparkFiles)) } else { Nil } @@ -139,33 +123,12 @@ private[spark] class DriverConfigOrchestrator( Nil } - val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { - val orchestrator = new InitContainerConfigOrchestrator( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - imagePullPolicy, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME, - sparkConf) - val bootstrapStep = new DriverInitContainerBootstrapStep( - orchestrator.getAllConfigurationSteps, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME) - - Seq(bootstrapStep) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - mountSecretsStep ++ - initContainerBootstrapStep + mountSecretsStep } private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 5884348cb3e41..e16d1add600b2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -16,14 +16,14 @@ */ package org.apache.spark.deploy.k8s.submit +import java.io.StringWriter import java.util.{Collections, UUID} - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.util.control.NonFatal +import java.util.Properties import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication @@ -32,6 +32,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -93,10 +94,8 @@ private[spark] class Client( kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - watcher: LoggingPodStatusWatcher) extends Logging { - - private val driverJavaOptions = sparkConf.get( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + watcher: LoggingPodStatusWatcher, + kubernetesResourceNamePrefix: String) extends Logging { /** * Run command that initializes a DriverSpec that will be updated after each @@ -110,33 +109,31 @@ private[spark] class Client( for (nextStep <- submissionSteps) { currentDriverSpec = nextStep.configureDriver(currentDriverSpec) } - - val resolvedDriverJavaOpts = currentDriverSpec - .driverSparkConf - // Remove this as the options are instead extracted and set individually below using - // environment variables with prefix SPARK_JAVA_OPT_. - .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) - .getAll - .map { - case (confKey, confValue) => s"-D$confKey=$confValue" - } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) - val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map { - case (option, index) => - new EnvVarBuilder() - .withName(s"$ENV_JAVA_OPT_PREFIX$index") - .withValue(option) - .build() - } - + val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" + val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the + // Spark command builder to pickup on the Java Options present in the ConfigMap val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) - .addAllToEnv(driverJavaOptsEnvs.asJava) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() .build() val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) .editSpec() .addToContainers(resolvedDriverContainer) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .endConfigMap() + .endVolume() .endSpec() .build() - Utils.tryWithResource( kubernetesClient .pods() @@ -145,7 +142,8 @@ private[spark] class Client( val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = currentDriverSpec.otherKubernetesResources + val otherKubernetesResources = + currentDriverSpec.otherKubernetesResources ++ Seq(configMap) addDriverOwnerReference(createdDriverPod, otherKubernetesResources) kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } @@ -180,6 +178,26 @@ private[spark] class Client( originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference)) } } + + // Build a Config Map that will house spark conf properties in a single file for spark-submit + private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + val properties = new Properties() + conf.getAll.foreach { case (k, v) => + properties.setProperty(k, v) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName") + + val namespace = conf.get(KUBERNETES_NAMESPACE) + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .withNamespace(namespace) + .endMetadata() + .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) + .build() + } } /** @@ -202,6 +220,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") + val kubernetesResourceNamePrefix = { + s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") + } // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -211,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, - launchTime, + kubernetesResourceNamePrefix, clientArguments.mainAppResource, appName, clientArguments.mainClass, @@ -231,7 +252,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesClient, waitForAppCompletion, appName, - watcher) + watcher, + kubernetesResourceNamePrefix) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 164e2e5594778..347c4d2d66826 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} +import org.apache.spark.launcher.SparkLauncher /** * Performs basic configuration for the driver pod. @@ -56,8 +57,6 @@ private[spark] class BasicDriverConfigurationStep( // Memory settings private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val driverMemoryString = sparkConf.get( - DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) @@ -103,24 +102,12 @@ private[spark] class BasicDriverConfigurationStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } - val driverContainer = new ContainerBuilder(driverSpec.driverContainer) + val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(imagePullPolicy) .addAllToEnv(driverCustomEnvs.asJava) .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_MEMORY) - .withValue(driverMemoryString) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_MAIN_CLASS) - .withValue(mainClass) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.mkString(" ")) - .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) .withValueFrom(new EnvVarSourceBuilder() @@ -134,7 +121,16 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") - .build() + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + + val driverContainer = appArgs.toList match { + case "" :: Nil | Nil => driverContainerWithoutArgs.build() + case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() + } val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() @@ -152,10 +148,14 @@ private[spark] class BasicDriverConfigurationStep( .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) + // to set the config variables to allow client-mode spark-submit from driver + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) driverSpec.copy( driverPod = baseDriverPod, driverSparkConf = resolvedSparkConf, driverContainer = driverContainer) } + } + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index d4b83235b4e3b..43de329f239ad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -30,13 +30,11 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec */ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String) extends DriverConfigurationStep { + sparkFiles: Seq[String]) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { @@ -45,14 +43,12 @@ private[spark] class DependencyResolutionStep( if (resolvedSparkFiles.nonEmpty) { sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - - val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { + val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedClasspath.mkString(File.pathSeparator)) - .endEnv() + .withName(ENV_MOUNTED_CLASSPATH) + .withValue(resolvedSparkJars.mkString(File.pathSeparator)) + .endEnv() .build() } else { driverSpec.driverContainer diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala deleted file mode 100644 index 9fb3dafdda540..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringWriter -import java.util.Properties - -import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} - -/** - * Configures the driver init-container that localizes remote dependencies into the driver pod. - * It applies the given InitContainerConfigurationSteps in the given order to produce a final - * InitContainerSpec that is then used to configure the driver pod with the init-container attached. - * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries - * configuration properties for the init-container. - */ -private[spark] class DriverInitContainerBootstrapStep( - steps: Seq[InitContainerConfigurationStep], - configMapName: String, - configMapKey: String) - extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - var initContainerSpec = InitContainerSpec( - properties = Map.empty[String, String], - driverSparkConf = Map.empty[String, String], - initContainer = new ContainerBuilder().build(), - driverContainer = driverSpec.driverContainer, - driverPod = driverSpec.driverPod, - dependentResources = Seq.empty[HasMetadata]) - for (nextStep <- steps) { - initContainerSpec = nextStep.configureInitContainer(initContainerSpec) - } - - val configMap = buildConfigMap( - configMapName, - configMapKey, - initContainerSpec.properties) - val resolvedDriverSparkConf = driverSpec.driverSparkConf - .clone() - .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) - .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) - .setAll(initContainerSpec.driverSparkConf) - val resolvedDriverPod = KubernetesUtils.appendInitContainer( - initContainerSpec.driverPod, initContainerSpec.initContainer) - - driverSpec.copy( - driverPod = resolvedDriverPod, - driverContainer = initContainerSpec.driverContainer, - driverSparkConf = resolvedDriverSparkConf, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ - initContainerSpec.dependentResources ++ - Seq(configMap)) - } - - private def buildConfigMap( - configMapName: String, - configMapKey: String, - config: Map[String, String]): ConfigMap = { - val properties = new Properties() - config.foreach { entry => - properties.setProperty(entry._1, entry._2) - } - val propertiesWriter = new StringWriter() - properties.store(propertiesWriter, - s"Java properties built from Kubernetes config map with name: $configMapName " + - s"and config map key: $configMapKey") - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(configMapKey, propertiesWriter.toString) - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala index ccc18908658f1..2424e63999a82 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala @@ -99,7 +99,7 @@ private[spark] class DriverKubernetesCredentialsStep( }.getOrElse(driverSpec.driverPod) ) - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret => + val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => new ContainerBuilder(driverSpec.driverContainer) .addNewVolumeMount() .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala deleted file mode 100644 index 01469853dacc2..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils - -/** - * Performs basic configuration for the driver init-container with most of the work delegated to - * the given InitContainerBootstrap. - */ -private[spark] class BasicInitContainerConfigurationStep( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - bootstrap: InitContainerBootstrap) - extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { - val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) - val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) - val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) - } else { - Map() - } - val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) - } else { - Map() - } - - val baseInitContainerConfig = Map( - JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, - FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ - remoteJarsConf ++ - remoteFilesConf - - val bootstrapped = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - spec.driverPod, - spec.initContainer, - spec.driverContainer)) - - spec.copy( - initContainer = bootstrapped.initContainer, - driverContainer = bootstrapped.mainContainer, - driverPod = bootstrapped.pod, - properties = spec.properties ++ baseInitContainerConfig) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala deleted file mode 100644 index f2c29c7ce1076..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to - * configure the driver init-container. The returned steps will be applied in the given order to - * produce a final InitContainerSpec that is used to construct the driver init-container in - * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., - * when there are remote application dependencies to localize. - */ -private[spark] class InitContainerConfigOrchestrator( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - imagePullPolicy: String, - configMapName: String, - configMapKey: String, - sparkConf: SparkConf) { - - private val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - - def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { - val initContainerBootstrap = new InitContainerBootstrap( - initContainerImage, - imagePullPolicy, - jarsDownloadPath, - filesDownloadPath, - configMapName, - configMapKey, - SPARK_POD_DRIVER_ROLE, - sparkConf) - val baseStep = new BasicInitContainerConfigurationStep( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - initContainerBootstrap) - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - // Mount user-specified driver secrets also into the driver's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The driver's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq(baseStep) ++ mountSecretsStep - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala deleted file mode 100644 index 0372ad5270951..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -/** - * Represents a step in configuring the driver init-container. - */ -private[spark] trait InitContainerConfigurationStep { - - def configureInitContainer(spec: InitContainerSpec): InitContainerSpec -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala deleted file mode 100644 index 0daa7b95e8aae..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -/** - * An init-container configuration step for mounting user-specified secrets onto user-specified - * paths. - * - * @param bootstrap a utility actually handling mounting of the secrets - */ -private[spark] class InitContainerMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - // Mount the secret volumes given that the volumes have already been added to the driver pod - // when mounting the secrets into the main driver container. - val initContainer = bootstrap.mountSecrets(spec.initContainer) - spec.copy(initContainer = initContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala deleted file mode 100644 index b52c343f0c0ed..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} - -/** - * Represents a specification of the init-container for the driver pod. - * - * @param properties properties that should be set on the init-container - * @param driverSparkConf Spark configuration properties that will be carried back to the driver - * @param initContainer the init-container object - * @param driverContainer the driver container object - * @param driverPod the driver pod object - * @param dependentResources resources the init-container depends on to work - */ -private[spark] case class InitContainerSpec( - properties: Map[String, String], - driverSparkConf: Map[String, String], - initContainer: Container, - driverContainer: Container, - driverPod: Pod, - dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 141bd2827e7c5..98cbd5607da00 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -34,18 +34,10 @@ import org.apache.spark.util.Utils * @param sparkConf Spark configuration * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto * user-specified paths into the executor container - * @param initContainerBootstrap an optional component for bootstrapping the executor init-container - * if one is needed, i.e., when there are remote dependencies to - * localize - * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified - * secrets onto user-specified paths into the executor - * init-container */ private[spark] class ExecutorPodFactory( sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap], - initContainerBootstrap: Option[InitContainerBootstrap], - initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { + mountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) @@ -94,8 +86,6 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) - /** * Configure and construct an executor pod with the given parameters. */ @@ -147,8 +137,9 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId), - (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) @@ -221,30 +212,10 @@ private[spark] class ExecutorPodFactory( (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) - val (bootstrappedPod, bootstrappedContainer) = - initContainerBootstrap.map { bootstrap => - val podWithInitContainer = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - maybeSecretsMountedPod, - new ContainerBuilder().build(), - maybeSecretsMountedContainer)) - - val (pod, mayBeSecretsMountedInitContainer) = - initContainerMountSecretsBootstrap.map { bootstrap => - // Mount the secret volumes given that the volumes have already been added to the - // executor pod when mounting the secrets into the main executor container. - (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) - }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) - - val bootstrappedPod = KubernetesUtils.appendInitContainer( - pod, mayBeSecretsMountedInitContainer) - - (bootstrappedPod, podWithInitContainer.mainContainer) - }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) - new PodBuilder(bootstrappedPod) + new PodBuilder(maybeSecretsMountedPod) .editSpec() - .addToContainers(bootstrappedContainer) + .addToContainers(maybeSecretsMountedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index a942db6ae02db..ff5f6801da2a3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -33,7 +33,9 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && sc.deployMode == "client") { + if (masterURL.startsWith("k8s") && + sc.deployMode == "client" && + !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { throw new SparkException("Client mode is currently not supported for Kubernetes.") } @@ -44,74 +46,23 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val sparkConf = sc.getConf - val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) - val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) - - if (initContainerConfigMap.isEmpty) { - logWarning("The executor's init-container config map is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - if (initContainerConfigMapKey.isEmpty) { - logWarning("The executor's init-container config map key is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - // Only set up the bootstrap if they've provided both the config map key and the config map - // name. The config map might not be provided if init-containers aren't being used to - // bootstrap dependencies. - val initContainerBootstrap = for { - configMap <- initContainerConfigMap - configMapKey <- initContainerConfigMapKey - } yield { - val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - new InitContainerBootstrap( - initContainerImage, - sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), - sparkConf.get(JARS_DOWNLOAD_LOCATION), - sparkConf.get(FILES_DOWNLOAD_LOCATION), - configMap, - configMapKey, - SPARK_POD_EXECUTOR_ROLE, - sparkConf) - } - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) } else { None } - // Mount user-specified executor secrets also into the executor's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The executor's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && - executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, - Some(sparkConf.get(KUBERNETES_NAMESPACE)), + Some(sc.conf.get(KUBERNETES_NAMESPACE)), KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - sparkConf, + sc.conf, Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory( - sparkConf, - mountSecretBootstrap, - initContainerBootstrap, - initContainerMountSecretsBootstrap) + val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala deleted file mode 100644 index e0f29ecd0fb53..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.UUID - -import com.google.common.base.Charsets -import com.google.common.io.Files -import org.mockito.Mockito -import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.util.Utils - -class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { - - private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") - private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") - - private var downloadJarsDir: File = _ - private var downloadFilesDir: File = _ - private var downloadJarsSecretValue: String = _ - private var downloadFilesSecretValue: String = _ - private var fileFetcher: FileFetcher = _ - - override def beforeAll(): Unit = { - downloadJarsSecretValue = Files.toString( - new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) - downloadFilesSecretValue = Files.toString( - new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) - } - - before { - downloadJarsDir = Utils.createTempDir() - downloadFilesDir = Utils.createTempDir() - fileFetcher = mock[FileFetcher] - } - - after { - downloadJarsDir.delete() - downloadFilesDir.delete() - } - - test("Downloads from remote server should invoke the file fetcher") { - val sparkConf = getSparkConfForRemoteFileDownloads - val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) - initContainerUnderTest.run() - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) - } - - private def getSparkConfForRemoteFileDownloads: SparkConf = { - new SparkConf(true) - .set(INIT_CONTAINER_REMOTE_JARS, - "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") - .set(INIT_CONTAINER_REMOTE_FILES, - "http://localhost:9000/file.txt") - .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) - .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) - } - - private def createTempFile(extension: String): String = { - val dir = Utils.createTempDir() - val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") - Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) - file.getAbsolutePath - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index bf4ec04893204..6a501592f42a3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -38,6 +38,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_UID = "pod-id" private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" + private val KUBERNETES_RESOURCE_PREFIX = "resource-example" private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -61,6 +62,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ + private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) @@ -94,7 +96,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) @@ -108,62 +111,52 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { SecondTestConfigurationStep.containerName) } - test("The client should create the secondary Kubernetes resources.") { + test("The client should create Kubernetes resources") { + val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" + val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( submissionSteps, - new SparkConf(false), + new SparkConf(false) + .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues - assert(otherCreatedResources.size === 1) - val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret] - assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(createdResource.getData.asScala === + assert(otherCreatedResources.size === 2) + val secrets = otherCreatedResources.toArray + .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val configMaps = otherCreatedResources.toArray + .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) + assert(secrets.nonEmpty) + val secret = secrets.head + assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) + assert(secret.getData.asScala === Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences) + val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) assert(ownerReference.getName === createdPod.getMetadata.getName) assert(ownerReference.getKind === DRIVER_POD_KIND) assert(ownerReference.getUid === DRIVER_POD_UID) assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) - } - - test("The client should attach the driver container with the appropriate JVM options.") { - val sparkConf = new SparkConf(false) - .set("spark.logConf", "true") - .set( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, - "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails") - val submissionClient = new Client( - submissionSteps, - sparkConf, - kubernetesClient, - false, - "spark", - loggingPodStatusWatcher) - submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue + assert(configMaps.nonEmpty) + val configMap = configMaps.head + assert(configMap.getMetadata.getName === + s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") + assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( + "spark.custom-conf=custom-conf-value")) val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env => - env.getName.startsWith(ENV_JAVA_OPT_PREFIX) - }.sortBy(_.getName) - assert(driverJvmOptsEnvs.size === 4) - - val expectedJvmOptsValues = Seq( - "-Dspark.logConf=true", - s"-D${SecondTestConfigurationStep.sparkConfKey}=" + - s"${SecondTestConfigurationStep.sparkConfValue}", - "-XX:+HeapDumpOnOutOfMemoryError", - "-XX:+PrintGCDetails") - driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach { - case ((resolvedEnv, expectedJvmOpt), index) => - assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index") - assert(resolvedEnv.getValue === expectedJvmOpt) - } + val driverEnv = driverContainer.getEnv.asScala.head + assert(driverEnv.getName === ENV_SPARK_CONF_DIR) + assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) + val driverMount = driverContainer.getVolumeMounts.asScala.head + assert(driverMount.getName === SPARK_CONF_VOLUME) + assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) } test("Waiting for app completion should stall on the watcher") { @@ -173,7 +166,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, true, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } @@ -209,13 +203,11 @@ private object FirstTestConfigurationStep extends DriverConfigurationStep { } private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" val annotationValue = "submitted" val sparkConfKey = "spark.custom-conf" val sparkConfValue = "custom-conf-value" val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val modifiedPod = new PodBuilder(driverSpec.driverPod) .editMetadata() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 033d303e946fd..df34d2dbcb5be 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -25,7 +25,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val DRIVER_IMAGE = "driver-image" private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" - private val LAUNCH_TIME = 975256L + private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") @@ -38,7 +38,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -49,15 +49,14 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep] - ) + classOf[DependencyResolutionStep]) } test("Base submission steps without a main app resource.") { val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Option.empty, APP_NAME, MAIN_CLASS, @@ -67,31 +66,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { orchestrator, classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep] - ) - } - - test("Submission steps with an init-container.") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) - .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverInitContainerBootstrapStep]) + classOf[DriverKubernetesCredentialsStep]) } test("Submission steps with driver secrets to mount") { @@ -102,7 +77,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -122,7 +97,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, DRIVER_IMAGE) var orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, @@ -135,7 +110,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index b136f2c02ffba..ce068531c7673 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -73,16 +73,13 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - assert(preparedDriverSpec.driverContainer.getEnv.size === 7) + assert(preparedDriverSpec.driverContainer.getEnv.size === 4) val envs = preparedDriverSpec.driverContainer .getEnv .asScala .map(env => (env.getName, env.getValue)) .toMap assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(ENV_DRIVER_MEMORY) === "256M") - assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") @@ -112,7 +109,8 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX) + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") assert(resolvedSparkConf === expectedSparkConf) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala index 991b03cafb76c..ca43fc97dc991 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala @@ -29,24 +29,17 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DependencyResolutionStepSuite extends SparkFunSuite { private val SPARK_JARS = Seq( - "hdfs://localhost:9000/apps/jars/jar1.jar", - "file:///home/user/apps/jars/jar2.jar", - "local:///var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "local:///var/apps/jars/jar2.jar") private val SPARK_FILES = Seq( - "file:///home/user/apps/files/file1.txt", - "hdfs://localhost:9000/apps/files/file2.txt", - "local:///var/apps/files/file3.txt") - - private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars" - private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files" + "apps/files/file1.txt", + "local:///var/apps/files/file2.txt") test("Added dependencies should be resolved in Spark configuration and environment") { val dependencyResolutionStep = new DependencyResolutionStep( SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH) + SPARK_FILES) val driverPod = new PodBuilder().build() val baseDriverSpec = KubernetesDriverSpec( driverPod = driverPod, @@ -58,24 +51,19 @@ class DependencyResolutionStepSuite extends SparkFunSuite { assert(preparedDriverSpec.otherKubernetesResources.isEmpty) val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet val expectedResolvedSparkJars = Set( - "hdfs://localhost:9000/apps/jars/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "/var/apps/jars/jar2.jar") assert(resolvedSparkJars === expectedResolvedSparkJars) val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet val expectedResolvedSparkFiles = Set( - s"$FILES_DOWNLOAD_PATH/file1.txt", - s"hdfs://localhost:9000/apps/files/file2.txt", - s"/var/apps/files/file3.txt") + "apps/files/file1.txt", + "/var/apps/files/file2.txt") assert(resolvedSparkFiles === expectedResolvedSparkFiles) val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala assert(driverEnv.size === 1) assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = Set( - s"$JARS_DOWNLOAD_PATH/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + val expectedResolvedDriverClasspath = expectedResolvedSparkJars assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala deleted file mode 100644 index 758871e2ba356..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringReader -import java.util.Properties - -import scala.collection.JavaConverters._ - -import com.google.common.collect.Maps -import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} -import org.apache.spark.util.Utils - -class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { - - private val CONFIG_MAP_NAME = "spark-init-config-map" - private val CONFIG_MAP_KEY = "spark-init-config-map-key" - - test("The init container bootstrap step should use all of the init container steps") { - val baseDriverSpec = KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val initContainerSteps = Seq( - FirstTestInitContainerConfigurationStep, - SecondTestInitContainerConfigurationStep) - val bootstrapStep = new DriverInitContainerBootstrapStep( - initContainerSteps, - CONFIG_MAP_NAME, - CONFIG_MAP_KEY) - - val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === - FirstTestInitContainerConfigurationStep.additionalLabels) - val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(additionalDriverEnv.size === 1) - assert(additionalDriverEnv.head.getName === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) - assert(additionalDriverEnv.head.getValue === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) - - assert(preparedDriverSpec.otherKubernetesResources.size === 2) - assert(preparedDriverSpec.otherKubernetesResources.contains( - FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) - assert(preparedDriverSpec.otherKubernetesResources.exists { - case configMap: ConfigMap => - val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME - val configMapData = configMap.getData.asScala - val hasCorrectNumberOfEntries = configMapData.size == 1 - val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) - val initContainerProperties = new Properties() - Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { - initContainerProperties.load(_) - } - val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala - val expectedInitContainerProperties = Map( - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) - val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties - hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties - - case _ => false - }) - - val initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers - assert(initContainers.size() === 1) - val initContainerEnv = initContainers.get(0).getEnv.asScala - assert(initContainerEnv.size === 1) - assert(initContainerEnv.head.getName === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) - assert(initContainerEnv.head.getValue === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) - - val expectedSparkConf = Map( - INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - } -} - -private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - - val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") - val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" - val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" - val additionalKubernetesResource = new SecretBuilder() - .withNewMetadata() - .withName("test-secret") - .endMetadata() - .addToData("secret-key", "secret-value") - .build() - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val driverPod = new PodBuilder(initContainerSpec.driverPod) - .editOrNewMetadata() - .addToLabels(additionalLabels.asJava) - .endMetadata() - .build() - val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) - .addNewEnv() - .withName(additionalMainContainerEnvKey) - .withValue(additionalMainContainerEnvValue) - .endEnv() - .build() - initContainerSpec.copy( - driverPod = driverPod, - driverContainer = mainContainer, - dependentResources = initContainerSpec.dependentResources ++ - Seq(additionalKubernetesResource)) - } -} - -private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" - val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" - val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" - val additionalInitContainerPropertyValue = "testvalue" - val additionalDriverSparkConfKey = "spark.driver.testkey" - val additionalDriverSparkConfValue = "spark.driver.testvalue" - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val initContainer = new ContainerBuilder(initContainerSpec.initContainer) - .addNewEnv() - .withName(additionalInitContainerEnvKey) - .withValue(additionalInitContainerEnvValue) - .endEnv() - .build() - val initContainerProperties = initContainerSpec.properties ++ - Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) - val driverSparkConf = initContainerSpec.driverSparkConf ++ - Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) - initContainerSpec.copy( - initContainer = initContainer, - properties = initContainerProperties, - driverSparkConf = driverSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala deleted file mode 100644 index 4553f9f6b1d45..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito.when -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ - -class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val POD_LABEL = Map("bootstrap" -> "true") - private val INIT_CONTAINER_NAME = "init-container" - private val DRIVER_CONTAINER_NAME = "driver-container" - - @Mock - private var podAndInitContainerBootstrap : InitContainerBootstrap = _ - - before { - MockitoAnnotations.initMocks(this) - when(podAndInitContainerBootstrap.bootstrapInitContainer( - any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { - override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { - val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) - pod.copy( - pod = new PodBuilder(pod.pod) - .withNewMetadata() - .addToLabels("bootstrap", "true") - .endMetadata() - .withNewSpec().endSpec() - .build(), - initContainer = new ContainerBuilder() - .withName(INIT_CONTAINER_NAME) - .build(), - mainContainer = new ContainerBuilder() - .withName(DRIVER_CONTAINER_NAME) - .build() - )}}) - } - - test("additionalDriverSparkConf with mix of remote files and jars") { - val baseInitStep = new BasicInitContainerConfigurationStep( - SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - podAndInitContainerBootstrap) - val expectedDriverSparkConf = Map( - JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, - INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", - INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") - val initContainerSpec = InitContainerSpec( - Map.empty[String, String], - Map.empty[String, String], - new Container(), - new Container(), - new Pod, - Seq.empty[HasMetadata]) - val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) - assert(expectedDriverSparkConf === returnContainerSpec.properties) - assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) - assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala deleted file mode 100644 index 09b42e4484d86..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -class InitContainerConfigOrchestratorSuite extends SparkFunSuite { - - private val DOCKER_IMAGE = "init-container" - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" - private val CUSTOM_LABEL_KEY = "customLabel" - private val CUSTOM_LABEL_VALUE = "customLabelValue" - private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" - private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("including basic configuration step") { - val sparkConf = new SparkConf(true) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.lengthCompare(1) == 0) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - } - - test("including step to mount user-specified secrets") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.length === 2) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala deleted file mode 100644 index 7ac0bde80dfe6..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} - -class InitContainerMountSecretsStepSuite extends SparkFunSuite { - - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("mounts all given secrets") { - val baseInitContainerSpec = InitContainerSpec( - Map.empty, - Map.empty, - new ContainerBuilder().build(), - new ContainerBuilder().build(), - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - Seq.empty) - val secretNamesToMountPaths = Map( - SECRET_FOO -> SECRET_MOUNT_PATH, - SECRET_BAR -> SECRET_MOUNT_PATH) - - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) - val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( - baseInitContainerSpec) - val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.containerHasVolume( - initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a3c615be031d2..7755b93835047 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ -import org.mockito.{AdditionalAnswers, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito._ +import org.mockito.MockitoAnnotations import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.MountSecretsBootstrap class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { @@ -55,10 +53,11 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None, None, None) + val factory = new ExecutorPodFactory(baseConf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -89,7 +88,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -101,7 +100,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -116,11 +115,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val conf = baseConf.clone() val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - None, - None) + val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -138,50 +133,6 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } - test("init-container bootstrap step adds an init container") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - - val factory = new ExecutorPodFactory( - conf, - None, - Some(initContainerBootstrap), - None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getInitContainers.size() === 1) - checkOwnerReferences(executor, driverPodUid) - } - - test("init-container with secrets mount bootstrap") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - Some(initContainerBootstrap), - Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getVolumes.size() === 1) - assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) - assert(executor.getSpec.getInitContainers.size() === 1) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) - - checkOwnerReferences(executor, driverPodUid) - } - // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) @@ -197,8 +148,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null, - ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 491b7cf692478..9badf8556afc3 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -40,7 +40,6 @@ RUN set -ex && \ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY conf /opt/spark/conf COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples COPY data /opt/spark/data diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index d0cf284f035ea..3e166116aa3fd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -56,14 +56,10 @@ fi case "$SPARK_K8S_CMD" in driver) CMD=( - ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" - -cp "$SPARK_CLASSPATH" - -Xms$SPARK_DRIVER_MEMORY - -Xmx$SPARK_DRIVER_MEMORY - -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS - $SPARK_DRIVER_CLASS - $SPARK_DRIVER_ARGS + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" ) ;; @@ -83,14 +79,6 @@ case "$SPARK_K8S_CMD" in ) ;; - init) - CMD=( - "$SPARK_HOME/bin/spark-class" - "org.apache.spark.deploy.k8s.SparkPodInitContainer" - "$@" - ) - ;; - *) echo "Unknown command: $SPARK_K8S_CMD" 1>&2 exit 1 From 5f4deff19511b6870f056eba5489104b9cac05a9 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 19 Mar 2018 18:02:04 -0700 Subject: [PATCH 183/593] [SPARK-23660] Fix exception in yarn cluster mode when application ended fast ## What changes were proposed in this pull request? Yarn throws the following exception in cluster mode when the application is really small: ``` 18/03/07 23:34:22 WARN netty.NettyRpcEnv: Ignored failure: java.util.concurrent.RejectedExecutionException: Task java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask7c974942 rejected from java.util.concurrent.ScheduledThreadPoolExecutor1eea9d2d[Terminated, pool size = 0, active threads = 0, queued tasks = 0, completed tasks = 0] 18/03/07 23:34:22 ERROR yarn.ApplicationMaster: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:205) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:76) at org.apache.spark.deploy.yarn.YarnAllocator.(YarnAllocator.scala:102) at org.apache.spark.deploy.yarn.YarnRMClient.register(YarnRMClient.scala:77) at org.apache.spark.deploy.yarn.ApplicationMaster.registerAM(ApplicationMaster.scala:450) at org.apache.spark.deploy.yarn.ApplicationMaster.runDriver(ApplicationMaster.scala:493) at org.apache.spark.deploy.yarn.ApplicationMaster.org$apache$spark$deploy$yarn$ApplicationMaster$$runImpl(ApplicationMaster.scala:345) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply$mcV$sp(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$5.run(ApplicationMaster.scala:810) at java.security.AccessController.doPrivileged(Native Method) at javax.security.auth.Subject.doAs(Subject.java:422) at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1920) at org.apache.spark.deploy.yarn.ApplicationMaster.doAsUser(ApplicationMaster.scala:809) at org.apache.spark.deploy.yarn.ApplicationMaster.run(ApplicationMaster.scala:259) at org.apache.spark.deploy.yarn.ApplicationMaster$.main(ApplicationMaster.scala:834) at org.apache.spark.deploy.yarn.ApplicationMaster.main(ApplicationMaster.scala) Caused by: org.apache.spark.rpc.RpcEnvStoppedException: RpcEnv already stopped. at org.apache.spark.rpc.netty.Dispatcher.postMessage(Dispatcher.scala:158) at org.apache.spark.rpc.netty.Dispatcher.postLocalMessage(Dispatcher.scala:135) at org.apache.spark.rpc.netty.NettyRpcEnv.ask(NettyRpcEnv.scala:229) at org.apache.spark.rpc.netty.NettyRpcEndpointRef.ask(NettyRpcEnv.scala:523) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:91) ... 17 more 18/03/07 23:34:22 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 13, (reason: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: ) ``` Example application: ``` object ExampleApp { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("ExampleApp") val sc = new SparkContext(conf) try { // Do nothing } finally { sc.stop() } } ``` This PR pauses user class thread after `SparkContext` created and keeps it so until application master initialises properly. ## How was this patch tested? Automated: Existing unit tests Manual: Application submitted into small cluster Author: Gabor Somogyi Closes #20807 from gaborgsomogyi/SPARK-23660. --- .../spark/deploy/yarn/ApplicationMaster.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2f88feb0f1fdf..6e35d23def6f0 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -418,7 +418,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def sparkContextInitialized(sc: SparkContext) = { - sparkContextPromise.success(sc) + sparkContextPromise.synchronized { + // Notify runDriver function that SparkContext is available + sparkContextPromise.success(sc) + // Pause the user class thread in order to make proper initialization in runDriver function. + sparkContextPromise.wait() + } + } + + private def resumeDriver(): Unit = { + // When initialization in runDriver happened the user class thread has to be resumed. + sparkContextPromise.synchronized { + sparkContextPromise.notify() + } } private def registerAM( @@ -497,6 +509,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // if the user app did not create a SparkContext. throw new IllegalStateException("User did not initialize spark context!") } + resumeDriver() userClassThread.join() } catch { case e: SparkException if e.getCause().isInstanceOf[TimeoutException] => @@ -506,6 +519,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") + } finally { + resumeDriver() } } From 566321852b2d60641fe86acbc8914b4a7063b58e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Mar 2018 21:25:37 -0700 Subject: [PATCH 184/593] [SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible ## What changes were proposed in this pull request? https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e added an useful util ```python contextmanager def sql_conf(self, pairs): ... ``` to allow configuration set/unset within a block: ```python with self.sql_conf({"spark.blah.blah.blah", "blah"}) # test codes ``` This PR proposes to use this util where possible in PySpark tests. Note that there look already few places affecting tests without restoring the original value back in unittest classes. ## How was this patch tested? Manually tested via: ``` ./run-tests --modules=pyspark-sql --python-executables=python2 ./run-tests --modules=pyspark-sql --python-executables=python3 ``` Author: hyukjinkwon Closes #20830 from HyukjinKwon/cleanup-sql-conf. --- python/pyspark/sql/tests.py | 130 ++++++++++++++---------------------- 1 file changed, 50 insertions(+), 80 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a0d547ad620e5..39d6c5226f138 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2461,17 +2461,13 @@ def test_join_without_on(self): df1 = self.spark.range(1).toDF("a") df2 = self.spark.range(1).toDF("b") - try: - self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): actual = df1.join(df2, how="inner").collect() expected = [Row(a=0, b=0)] self.assertEqual(actual, expected) - finally: - # We should unset this. Otherwise, other tests are affected. - self.spark.conf.unset("spark.sql.crossJoin.enabled") # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): @@ -2943,21 +2939,18 @@ def test_create_dateframe_from_pandas_with_dst(self): self.assertPandasEqual(pdf, df.toPandas()) orig_env_tz = os.environ.get('TZ', None) - orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') try: tz = 'America/Los_Angeles' os.environ['TZ'] = tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', tz) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) finally: del os.environ['TZ'] if orig_env_tz is not None: os.environ['TZ'] = orig_env_tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) class HiveSparkSubmitTests(SparkSubmitTests): @@ -3562,12 +3555,11 @@ def test_null_conversion(self): self.assertTrue(all([c == 1 for c in null_counts])) def _toPandas_arrow_toggle(self, df): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): pdf = df.toPandas() - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + pdf_arrow = df.toPandas() + return pdf, pdf_arrow def test_toPandas_arrow_toggle(self): @@ -3579,16 +3571,17 @@ def test_toPandas_arrow_toggle(self): def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) self.assertPandasEqual(pdf_arrow_ny, pdf_ny) @@ -3601,8 +3594,6 @@ def test_toPandas_respect_session_timezone(self): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) self.assertPandasEqual(pdf_ny, pdf_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() @@ -3618,12 +3609,11 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) def _createDataFrame_toggle(self, pdf, schema=None): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow def test_createDataFrame_toggle(self): @@ -3634,18 +3624,18 @@ def test_createDataFrame_toggle(self): def test_createDataFrame_respect_session_timezone(self): from datetime import timedelta pdf = self.create_pandas_data_frame() - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) result_ny = df_no_arrow_ny.collect() result_arrow_ny = df_arrow_ny.collect() @@ -3658,8 +3648,6 @@ def test_createDataFrame_respect_session_timezone(self): for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() @@ -4336,9 +4324,7 @@ def gen_timestamps(id): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import pandas_udf, col import pandas as pd - orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) - try: + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @pandas_udf(returnType=LongType()) @@ -4348,11 +4334,6 @@ def check_records_per_batch(x): result = df.select(check_records_per_batch(col("id"))).collect() for (r,) in result: self.assertTrue(r <= 3) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - else: - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) def test_vectorized_udf_timestamps_respect_session_timezone(self): from pyspark.sql.functions import pandas_udf, col @@ -4371,30 +4352,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): internal_value = pandas_udf( lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() self.assertNotEqual(result_ny, result_la) self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations @@ -5170,9 +5148,7 @@ def test_complex_expressions(self): def test_retain_group_columns(self): from pyspark.sql.functions import sum, lit, col - orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) - self.spark.conf.set("spark.sql.retainGroupColumns", False) - try: + with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -5180,12 +5156,6 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.retainGroupColumns") - else: - self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) - def test_invalid_args(self): from pyspark.sql.functions import mean From 5e7bc2acef4a1e11d0d8056ef5c12cd5c8f220da Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 20 Mar 2018 10:34:56 -0700 Subject: [PATCH 185/593] [SPARK-23649][SQL] Skipping chars disallowed in UTF-8 ## What changes were proposed in this pull request? The mapping of UTF-8 char's first byte to char's size doesn't cover whole range 0-255. It is defined only for 0-253: https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L60-L65 https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L190 If the first byte of a char is 253-255, IndexOutOfBoundsException is thrown. Besides of that values for 244-252 are not correct according to recent unicode standard for UTF-8: http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf As a consequence of the exception above, the length of input string in UTF-8 encoding cannot be calculated if the string contains chars started from 253 code. It is visible on user's side as for example crashing of schema inferring of csv file which contains such chars but the file can be read if the schema is specified explicitly or if the mode set to multiline. The proposed changes build correct mapping of first byte of UTF-8 char to its size (now it covers all cases) and skip disallowed chars (counts it as one octet). ## How was this patch tested? Added a test and a file with a char which is disallowed in UTF-8 - 0xFF. Author: Maxim Gekk Closes #20796 from MaxGekk/skip-wrong-utf8-chars. --- .../apache/spark/unsafe/types/UTF8String.java | 48 +++++++++++++++---- .../spark/unsafe/types/UTF8StringSuite.java | 23 ++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b0d0c44823e68..5d468aed42337 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -57,12 +57,43 @@ public final class UTF8String implements Comparable, Externalizable, public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } - private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6}; + /** + * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which + * indicates the size of the char. See Unicode standard in page 126, Table 3-6: + * http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf + * + * Binary Hex Comments + * 0xxxxxxx 0x00..0x7F Only byte of a 1-byte character encoding + * 10xxxxxx 0x80..0xBF Continuation bytes (1-3 continuation bytes) + * 110xxxxx 0xC0..0xDF First byte of a 2-byte character encoding + * 1110xxxx 0xE0..0xEF First byte of a 3-byte character encoding + * 11110xxx 0xF0..0xF7 First byte of a 4-byte character encoding + * + * As a consequence of the well-formedness conditions specified in + * Table 3-7 (page 126), the following byte values are disallowed in UTF-8: + * C0–C1, F5–FF. + */ + private static byte[] bytesOfCodePointInUTF8 = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F + // Continuation bytes cannot appear as the first byte + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF + 0, 0, // 0xC0..0xC1 - disallowed in UTF-8 + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF + 4, 4, 4, 4, 4, // 0xF0..0xF4 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 + }; private static final boolean IS_LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; @@ -187,8 +218,9 @@ public void writeTo(OutputStream out) throws IOException { * @param b The first byte of a code point */ private static int numBytesForFirstByte(final byte b) { - final int offset = (b & 0xFF) - 192; - return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + final int offset = b & 0xFF; + byte numBytes = bytesOfCodePointInUTF8[offset]; + return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b303fa5bc6c5..7c34d419574ef 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -58,8 +58,12 @@ private static void checkBasic(String str, int len) { @Test public void basicTest() { checkBasic("", 0); - checkBasic("hello", 5); + checkBasic("¡", 1); // 2 bytes char + checkBasic("ку", 2); // 2 * 2 bytes chars + checkBasic("hello", 5); // 5 * 1 byte chars checkBasic("大 千 世 界", 7); + checkBasic("︽﹋%", 3); // 3 * 3 bytes chars + checkBasic("\uD83E\uDD19", 1); // 4 bytes char } @Test @@ -791,4 +795,21 @@ public void trimRightWithTrimString() { assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); } + + @Test + public void skipWrongFirstByte() { + int[] wrongFirstBytes = { + 0x80, 0x9F, 0xBF, // Skip Continuation bytes + 0xC0, 0xC2, // 0xC0..0xC1 - disallowed in UTF-8 + // 0xF5..0xFF - disallowed in UTF-8 + 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, + 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF + }; + byte[] c = new byte[1]; + + for (int i = 0; i < wrongFirstBytes.length; ++i) { + c[0] = (byte)wrongFirstBytes[i]; + assertEquals(fromBytes(c).numChars(), 1); + } + } } From 7f5e8aa2606b0ee0297ceb6f4603bd368e3b0291 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 20 Mar 2018 11:14:34 -0700 Subject: [PATCH 186/593] [SPARK-21898][ML] Feature parity for KolmogorovSmirnovTest in MLlib ## What changes were proposed in this pull request? Feature parity for KolmogorovSmirnovTest in MLlib. Implement `DataFrame` interface for `KolmogorovSmirnovTest` in `mllib.stat`. ## How was this patch tested? Test suite added. Author: WeichenXu Author: jkbradley Closes #19108 from WeichenXu123/ml-ks-test. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 113 ++++++++++++++ .../stat/JavaKolmogorovSmirnovTestSuite.java | 84 +++++++++++ .../ml/stat/KolmogorovSmirnovTestSuite.scala | 140 ++++++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala new file mode 100644 index 0000000000000..8d80e7768cb6e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.annotation.varargs + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.function.Function +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col + +/** + * :: Experimental :: + * + * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * For more information on KS Test: + * @see + * Kolmogorov-Smirnov test (Wikipedia) + */ +@Experimental +@Since("2.4.0") +object KolmogorovSmirnovTest { + + /** Used to construct output schema of test */ + private case class KolmogorovSmirnovTestResult( + pValue: Double, + statistic: Double) + + private def getSampleRDD(dataset: DataFrame, sampleCol: String): RDD[Double] = { + SchemaUtils.checkNumericType(dataset.schema, sampleCol) + import dataset.sparkSession.implicits._ + dataset.select(col(sampleCol).cast("double")).as[Double].rdd + } + + /** + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } + + /** + * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) + } + + /** + * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability + * distribution equality. Currently supports the normal distribution, taking as parameters + * the mean and standard deviation. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param distName a `String` name for a theoretical distribution, currently only support "norm". + * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + @varargs + def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java new file mode 100644 index 0000000000000..021272dd5a40c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + + +public class JavaKolmogorovSmirnovTestSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = Arrays.asList(0.1, 1.1, 10.1, -1.1); + + dataset = spark.createDataset(points, Encoders.DOUBLE()).toDF("sample"); + } + + @Test + public void testKSTestCDF() { + // Create theoretical distributions + NormalDistribution stdNormalDist = new NormalDistribution(0, 1); + + // set seeds + Long seed = 10L; + stdNormalDist.reseedRandomGenerator(seed); + Function stdNormalCDF = (x) -> stdNormalDist.cumulativeProbability(x); + + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", stdNormalCDF).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } + + @Test + public void testKSTestNamedDistribution() { + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", "norm", 0.0, 1.0).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala new file mode 100644 index 0000000000000..1312de3a1b522 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.apache.commons.math3.distribution.{ExponentialDistribution, NormalDistribution, + RealDistribution, UniformRealDistribution} +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => Math3KSTest} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class KolmogorovSmirnovTestSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + def apacheCommonMath3EquivalenceTest( + sampleDist: RealDistribution, + theoreticalDist: RealDistribution, + theoreticalDistByName: (String, Array[Double]), + rejectNullHypothesis: Boolean): Unit = { + + // set seeds + val seed = 10L + sampleDist.reseedRandomGenerator(seed) + if (theoreticalDist != null) { + theoreticalDist.reseedRandomGenerator(seed) + } + + // Sample data from the distributions and parallelize it + val n = 100000 + val sampledArray = sampleDist.sample(n) + val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample") + + // Use a apache math commons local KS test to verify calculations + val ksTest = new Math3KSTest() + val pThreshold = 0.05 + + // Comparing a standard normal sample to a standard normal distribution + val Row(pValue1: Double, statistic1: Double) = + if (theoreticalDist != null) { + val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x) + KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head() + } else { + KolmogorovSmirnovTest.test(sampledDF, "sample", + theoreticalDistByName._1, + theoreticalDistByName._2: _* + ).head() + } + val theoreticalDistMath3 = if (theoreticalDist == null) { + assert(theoreticalDistByName._1 == "norm") + val params = theoreticalDistByName._2 + new NormalDistribution(params(0), params(1)) + } else { + theoreticalDist + } + val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(theoreticalDistMath3, sampledArray) + val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n) + // Verify vs apache math commons ks test + assert(statistic1 ~== referenceStat1 relTol 1e-4) + assert(pValue1 ~== referencePVal1 relTol 1e-4) + + if (rejectNullHypothesis) { + assert(pValue1 < pThreshold) + } else { + assert(pValue1 > pThreshold) + } + } + + test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") { + // Create theoretical distributions + val stdNormalDist = new NormalDistribution(0.0, 1.0) + val expDist = new ExponentialDistribution(0.6) + val uniformDist = new UniformRealDistribution(0.0, 1.0) + val expDist2 = new ExponentialDistribution(0.2) + val stdNormByName = Tuple2("norm", Array(0.0, 1.0)) + + apacheCommonMath3EquivalenceTest(stdNormalDist, null, stdNormByName, false) + apacheCommonMath3EquivalenceTest(expDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(uniformDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(expDist, expDist2, null, true) + } + + test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") { + /* + Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample + > sessionInfo() + R version 3.2.0 (2015-04-16) + Platform: x86_64-apple-darwin13.4.0 (64-bit) + > set.seed(20) + > v <- rnorm(20) + > v + [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612 + [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222 + [13] -0.62812676 1.32322085 -1.52135057 -0.43742787 0.97057758 0.02822264 + [19] -0.08578219 0.38921440 + > ks.test(v, pnorm, alternative = "two.sided") + + One-sample Kolmogorov-Smirnov test + + data: v + D = 0.18874, p-value = 0.4223 + alternative hypothesis: two-sided + */ + + val rKSStat = 0.18874 + val rKSPVal = 0.4223 + val rData = sc.parallelize( + Array( + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ) + ).toDF("sample") + val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest + .test(rData, "sample", "norm", 0, 1).head() + assert(statistic ~== rKSStat relTol 1e-4) + assert(pValue ~== rKSPVal relTol 1e-4) + } +} From 2c4b9962fdf8c1beb66126ca41628c72eb6c2383 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 20 Mar 2018 11:46:51 -0700 Subject: [PATCH 187/593] [SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. ## What changes were proposed in this pull request? Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. Note that this means reader factories end up being constructed as partitioning is checked; let me know if you think that could be a problem. ## How was this patch tested? existing unit tests Author: Jose Torres Author: Jose Torres Closes #20726 from jose-torres/SPARK-23574. --- .../v2/reader/SupportsReportPartitioning.java | 3 ++ .../datasources/v2/DataSourceRDD.scala | 4 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 29 ++++++++++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- .../sql/streaming/StreamingQuerySuite.scala | 4 +-- 6 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 5405a916951b8..607628746e873 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -23,6 +23,9 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. + * + * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving public interface SupportsReportPartitioning extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5ed0ba71e94c7..f85971be394b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index cb691ba297076..3a5e7bf89e142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. @@ -56,6 +58,15 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + SinglePartition + + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + SinglePartition + + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + SinglePartition + case s: SupportsReportPartitioning => new DataSourcePartitioning( s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) @@ -63,29 +74,33 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() + private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala case _ => reader.createDataReaderFactories().asScala.map { new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] - }.asJava + } } - private lazy val inputRDD: RDD[InternalRow] = reader match { + private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) - .asInstanceOf[RDD[InternalRow]] + r.createBatchDataReaderFactories().asScala + } + private lazy val inputRDD: RDD[InternalRow] = reader match { case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size())) + .askSync[Unit](SetReaderPartitions(readerFactories.size)) new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + case _ => new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cf02c0dda25d7..06754f01657d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1157a350461d8..e0a53272cd222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("SPARK-23574: no shuffle exchange with single partition") { + val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +} + class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 3f9aa0d1fa5be..ebc9a87b23f84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.durationMs.get("setOffsetRange") === 50) assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 0) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 350) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From 477d6bd7265e255fd43e53edda02019b32f29bb2 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 20 Mar 2018 13:27:50 -0700 Subject: [PATCH 188/593] [SPARK-23500][SQL] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? Complex type simplification optimizer rules were not applied to the entire plan, just the expressions reachable from the root node. This patch fixes the rules to transform the entire plan. ## How was this patch tested? New unit test + ran sql / core tests. Author: Henry Robinson Author: Henry Robinson Closes #20687 from henryr/spark-25000. --- .../sql/catalyst/optimizer/ComplexTypes.scala | 61 ++++++++----------- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../optimizer/complexTypesSuite.scala | 55 +++++++++++++++-- 3 files changed, 76 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index be0009ec8c760..db7d6d3254bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,39 +18,39 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** -* push down operations into [[CreateNamedStructLike]]. -*/ -object SimplifyCreateStructOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field extraction + * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions. + */ +object SimplifyExtractValueOps extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // One place where this optimization is invalid is an aggregation where the select + // list expression is a function of a grouping expression: + // + // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) + // + // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this + // optimization for Aggregates (although this misses some cases where the optimization + // can be made). + case a: Aggregate => a + case p => p.transformExpressionsUp { + // Remove redundant field extraction. case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) => createNamedStructLike.valExprs(ordinal) - } - } -} -/** -* push down operations into [[CreateArray]]. -*/ -object SimplifyCreateArrayOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field selection (array of structs) - case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) => - // instead f selecting the field on the entire array, - // select it from each member of the array. - // pushing down the operation this way open other optimizations opportunities - // (i.e. struct(...,x,...).x) + // Remove redundant array indexing. + case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => + // Instead of selecting the field on the entire array, select it from each member + // of the array. Pushing down the operation this way may open other optimizations + // opportunities (i.e. struct(...,x,...).x) CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name)))) - // push down item selection. + + // Remove redundant map lookup. case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) => - // instead of creating the array and then selecting one row, - // remove array creation altgether. + // Instead of creating the array and then selecting one row, remove array creation + // altogether. if (idx >= 0 && idx < elems.size) { // valid index elems(idx) @@ -58,18 +58,7 @@ object SimplifyCreateArrayOps extends Rule[LogicalPlan] { // out of bounds, mimic the runtime behavior and return null Literal(null, ga.dataType) } - } - } -} - -/** -* push down operations into [[CreateMap]]. -*/ -object SimplifyCreateMapOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems) } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 91208479be03b..2829d1d81eb1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) EliminateSerialization, RemoveRedundantAliases, RemoveRedundantProject, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps, + SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index de544ac314789..e44a6692ad8e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -44,14 +44,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { BooleanSimplification, SimplifyConditionals, SimplifyBinaryComparison, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: Nil + SimplifyExtractValueOps) :: Nil } val idAtt = ('id).long.notNull + val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt ) + lazy val relation = LocalRelation(idAtt, nullableIdAtt) test("explicit get from namedStruct") { val query = relation @@ -321,7 +320,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( CaseWhen(Seq( (EqualTo(2L, 'id), ('id + 1L)), - // these two are possible matches, we can't tell untill runtime + // these two are possible matches, we can't tell until runtime (EqualTo(2L, ('id + 1L)), ('id + 2L)), (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)), // this is a definite match (two constants), @@ -331,4 +330,50 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .analyze comparePlans(Optimizer execute rel, expected) } + + test("SPARK-23500: Simplify complex ops that aren't at the plan root") { + val structRel = relation + .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") + .groupBy($"foo")("1").analyze + val structExpected = relation + .select('nullable_id as "foo") + .groupBy($"foo")("1").analyze + comparePlans(Optimizer execute structRel, structExpected) + + // These tests must use nullable attributes from the base relation for the following reason: + // in the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to a1, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. In the 'expected' plans, + // the grouping expressions have the same nullability as the original attribute in the + // relation. If that attribute is non-nullable, the tests will fail as the plans will + // compare differently, so for these tests we must use a nullable attribute. See + // SPARK-23634. + val arrayRel = relation + .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") + .groupBy($"a1")("1").analyze + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze + comparePlans(Optimizer execute arrayRel, arrayExpected) + + val mapRel = relation + .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") + .groupBy($"m1")("1").analyze + val mapExpected = relation + .select('nullable_id as "m1") + .groupBy($"m1")("1").analyze + comparePlans(Optimizer execute mapRel, mapExpected) + } + + test("SPARK-23500: Ensure that aggregation expressions are not simplified") { + // Make sure that aggregation exprs are correctly ignored. Maps can't be used in + // grouping exprs so aren't tested here. + val structAggRel = relation.groupBy( + CreateNamedStruct(Seq("att1", 'nullable_id)))( + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze + comparePlans(Optimizer execute structAggRel, structAggRel) + + val arrayAggRel = relation.groupBy( + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze + comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + } } From 983e8d9d64b6b1304c43ea6e5dffdc1078138ef9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 20 Mar 2018 23:17:49 -0700 Subject: [PATCH 189/593] [SPARK-23666][SQL] Do not display exprIds of Alias in user-facing info. ## What changes were proposed in this pull request? To drop `exprId`s for `Alias` in user-facing info., this pr added an entry for `Alias` in `NonSQLExpression.sql` ## How was this patch tested? Added tests in `UDFSuite`. Author: Takeshi Yamamuro Closes #20827 from maropu/SPARK-23666. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/expressions/Expression.scala | 1 + .../scala/org/apache/spark/sql/UDFSuite.scala | 132 ++++++++++-------- 3 files changed, 78 insertions(+), 56 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0e092e0e37ccf..5b47fd77f2cbc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1806,6 +1806,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d7f9e38915dd5..38caf67d465d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -288,6 +288,7 @@ trait NonSQLExpression extends Expression { final override def sql: String = { transform { case a: Attribute => new PrettyAttribute(a) + case a: Alias => PrettyAttribute(a.sql, a.dataType) }.toString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index af6a10b425b9f..21afdc7e2a33f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -144,73 +144,81 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a WHERE") { - spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + withTempView("integerData") { + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.createOrReplaceTempView("integerData") + val df = sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("integerData") - val result = - sql("SELECT * FROM integerData WHERE oneArgFilter(key)") - assert(result.count() === 20) + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } } test("UDF in a HAVING") { - spark.udf.register("havingFilter", (n: Long) => { n > 5 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT g, SUM(v) as s - | FROM groupData - | GROUP BY g - | HAVING havingFilter(s) - """.stripMargin) - - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } } test("UDF in a GROUP BY") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT SUM(v) - | FROM groupData - | GROUP BY groupFunction(v) - """.stripMargin) - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } } test("UDFs everywhere") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) - spark.udf.register("whereFilter", (n: Int) => { n < 150 }) - spark.udf.register("timesHundred", (n: Long) => { n * 100 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT timesHundred(SUM(v)) as v100 - | FROM groupData - | WHERE whereFilter(v) - | GROUP BY groupFunction(v) - | HAVING havingFilter(v100) - """.stripMargin) - assert(result.count() === 1) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } } test("struct UDF") { @@ -304,4 +312,16 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } + + test("SPARK-23666 Do not display exprId in argument names") { + withTempView("x") { + Seq(((1, 2), 3)).toDF("a", "b").createOrReplaceTempView("x") + spark.udf.register("f", (a: Int) => a) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + spark.sql("SELECT f(a._1) FROM x").show + } + assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) + } + } } From 500b21c3d6247015e550be7e144e9b4b26fe28be Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Mar 2018 10:19:02 -0500 Subject: [PATCH 190/593] [SPARK-23568][ML] Use metadata numAttributes if available in Silhouette ## What changes were proposed in this pull request? Silhouette need to know the number of features. This was taken using `first` and checking the size of the vector. Despite this works fine, if the number of attributes is present in metadata, we can avoid to trigger a job for this and use the metadata value. This can help improving performances of course. ## How was this patch tested? existing UTs + added UT Author: Marco Gaido Closes #20719 from mgaido91/SPARK-23568. --- .../ml/evaluation/ClusteringEvaluator.scala | 22 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 8d4ae562b3d2b..4353c46781e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} @@ -170,6 +171,15 @@ private[evaluation] abstract class Silhouette { def overallScore(df: DataFrame, scoreColumn: Column): Double = { df.select(avg(scoreColumn)).collect()(0).getDouble(0) } + + protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = { + val group = AttributeGroup.fromStructField(dataFrame.schema(columnName)) + if (group.size < 0) { + dataFrame.select(col(columnName)).first().getAs[Vector](0).size + } else { + group.size + } + } } /** @@ -360,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { df: DataFrame, predictionCol: String, featuresCol: String): Map[Double, ClusterStats] = { - val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm")) .rdd @@ -552,8 +562,11 @@ private[evaluation] object CosineSilhouette extends Silhouette { * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). */ - def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { - val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + def computeClusterStats( + df: DataFrame, + featuresCol: String, + predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) .rdd @@ -626,7 +639,8 @@ private[evaluation] object CosineSilhouette extends Silhouette { normalizeFeatureUDF(col(featuresCol))) // compute aggregate values for clusters needed by the algorithm - val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, + predictionCol) // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 3bf34770f5687..2c175ff68e0b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ @@ -100,4 +102,23 @@ class ClusteringEvaluatorSuite } } + test("SPARK-23568: we should use metadata to determine features number") { + val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size + val attrGroup = new AttributeGroup("features", attributesNum) + val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label") + require(AttributeGroup.fromStructField(df.schema("features")) + .numAttributes.isDefined, "numAttributes metadata should be defined") + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + + // with the proper metadata we compute correctly the result + assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5) + + val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1) + val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()), + $"label") + // with wrong metadata the evaluator throws an Exception + intercept[SparkException](evaluator.evaluate(dfWrong)) + } } From bf09f2f71276d3b3a84a8f89109bd785a066c3e6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 21 Mar 2018 09:39:14 -0700 Subject: [PATCH 191/593] [SPARK-10884][ML] Support prediction on single instance for regression and classification related models ## What changes were proposed in this pull request? Support prediction on single instance for regression and classification related models (i.e., PredictionModel, ClassificationModel and their sub classes). Add corresponding test cases. ## How was this patch tested? Test cases added. Author: WeichenXu Closes #19381 from WeichenXu123/single_prediction. --- .../scala/org/apache/spark/ml/Predictor.scala | 5 ++-- .../spark/ml/classification/Classifier.scala | 6 ++--- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 2 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 17 ++++++++++++- .../classification/GBTClassifierSuite.scala | 9 +++++++ .../ml/classification/LinearSVCSuite.scala | 6 +++++ .../LogisticRegressionSuite.scala | 9 +++++++ .../MultilayerPerceptronClassifierSuite.scala | 12 ++++++++++ .../ml/classification/NaiveBayesSuite.scala | 22 +++++++++++++++++ .../RandomForestClassifierSuite.scala | 16 +++++++++++++ .../DecisionTreeRegressorSuite.scala | 15 ++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 8 +++++++ .../GeneralizedLinearRegressionSuite.scala | 8 +++++++ .../ml/regression/LinearRegressionSuite.scala | 7 ++++++ .../RandomForestRegressorSuite.scala | 24 +++++++++++++++---- .../org/apache/spark/ml/util/MLTest.scala | 15 ++++++++++-- 25 files changed, 176 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 08b0cb9b8f6a5..d8f3dfa874439 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. */ - protected def predict(features: FeaturesType): Double + @Since("2.4.0") + def predict(features: FeaturesType): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 9d1d5aa1e0cff..7e5790ab70ee9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from `predictRaw()`. */ - override protected def predict(features: FeaturesType): Double = { + override def predict(features: FeaturesType): Double = { raw2prediction(predictRaw(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec52..65cce697d8202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f11bc1d8fe415..cd44489f618b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -267,7 +267,7 @@ class GBTClassificationModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization if (isDefined(thresholds)) { super.predict(features) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ce400f4f1faf7..8f950cd28c3aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -316,7 +316,7 @@ class LinearSVCModel private[classification] ( BLAS.dot(features, coefficients) + intercept } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { if (margin(features) > $(threshold)) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fa191604218db..3ae4db3f3f965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] ( * Predict label for the given feature vector. * The behavior of this can be adjusted using `thresholds`. */ - override protected def predict(features: Vector): Double = if (isMultinomial) { + override def predict(features: Vector): Double = if (isMultinomial) { super.predict(features) } else { // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index fd4c98f22132f..af2e4699924e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * Predict label for the given features. * This internal method is used to implement `transform()` and output [[predictionCol]]. */ - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { LabelConverter.decodeLabel(mlpModel.predict(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0291a57487c47..ad154fcd010cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f41d15b62dddd..6569ff2a5bfc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -230,7 +230,7 @@ class GBTRegressionModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 917a4d238d467..9f1f2405c428e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { predict(features, 0.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6d3fe7a6c748c..92510154d500e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] ( } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { dot(features, coefficients) + intercept } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 200b234b79978..2d594460c2475 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index eeb0324187c5b..2930f4900d50e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, DecisionTreeClassificationModel](this, newTree, newData) } + test("prediction on single instance") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + testPredictionModelSinglePrediction(newTree, newData) + } + test("training with 1-category categorical feature") { val data = sc.parallelize(Seq( LabeledPoint(0, Vectors.dense(0, 2, 3)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 092b4a01d5b0d..57796069f6052 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, GBTClassificationModel](this, gbtModel, validationDataset) } + test("prediction on single instance") { + + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF("label", "features") + val gbtModel = gbt.fit(trainingDataset) + + testPredictionModelSinglePrediction(gbtModel, trainingDataset) + } + test("GBT parameter stepSize should be in interval (0, 1]") { withClue("GBT parameter stepSize should be in interval (0, 1]") { intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a93825b8a812d..c05c896df5cb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { dataset.as[LabeledPoint], estimator, modelEquals, 42L) } + test("prediction on single instance") { + val trainer = new LinearSVC() + val model = trainer.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(model, smallBinaryDataset) + } + test("linearSVC comparison with R e1071 and scikit-learn") { val trainer1 = new LinearSVC() .setRegParam(0.00002) // set regParam = 2.0 / datasize / c diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9987cbf6ba116..36b7e51f93d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } + test("prediction on single instance") { + val blor = new LogisticRegression().setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(blorModel, smallBinaryDataset) + val mlor = new LogisticRegression().setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) + testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset) + } + test("coefficients and intercept methods") { val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") val mlrModel = mlr.fit(smallMultinomialDataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index daa58a56896d7..6b5fe6e49ffea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe } } + test("prediction on single instance") { + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(dataset) + testPredictionModelSinglePrediction(model, dataset) + } + test("Predicted class probabilities: calibration on toy dataset") { val layers = Array[Int](4, 5, 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 49115c8a4db30..5f9ab98a2c3ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { Vector, NaiveBayesModel](this, model, testDataset) } + test("prediction on single instance") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val trainDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF() + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") + val model = nb.fit(trainDataset) + + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() + + testPredictionModelSinglePrediction(model, validationDataset) + } + test("Naive Bayes with weighted samples") { val numClasses = 3 def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 02a9d5c2a18c0..ba4a9cf082785 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, RandomForestClassificationModel](this, model, df) } + test("prediction on single instance") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + testPredictionModelSinglePrediction(model, df) + } + test("Fitting without numClasses in metadata") { val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 68a1218c23ece..29a438396516b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { assert(importances.toArray.forall(_ >= 0.0)) } + test("prediction on single instance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 11c593b521e65..fad11d078250f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(trainData.toDF()) + testPredictionModelSinglePrediction(model, validationData.toDF) + } + test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ef2ff94a5e213..d5bcbb221783e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(model.getLink === "identity") } + test("prediction on single instance") { + val glr = new GeneralizedLinearRegression + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + testPredictionModelSinglePrediction(model, datasetGaussianIdentity) + } + test("generalized linear regression: gaussian family against glm") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d42cb1714478f..9b19f63eba1bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val trainer = new LinearRegression + val model = trainer.fit(datasetWithDenseFeature) + + testPredictionModelSinglePrediction(model, datasetWithDenseFeature) + } + test("linear regression model with constant label") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 8b8e8a655f47b..e83c49f932973 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest{ +class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ @@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("prediction on single instance") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("Feature importance with toy data") { val rf = new RandomForestRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 795fd0e2ac0e4..76d41f9b23715 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,8 +22,9 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.Transformer -import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.ml.{PredictionModel, Transformer} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest @@ -136,4 +137,14 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => assert(hasExpectedMessage(exceptionOnStreamData)) } } + + def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _], + dataset: Dataset[_]): Unit = { + + model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol) + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction === model.predict(features)) + } + } } From 8d79113b812a91073d2c24a3a9ad94cc3b90b24a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 21 Mar 2018 09:46:47 -0700 Subject: [PATCH 192/593] [SPARK-23577][SQL] Supports custom line separator for text datasource ## What changes were proposed in this pull request? This PR proposes to add `lineSep` option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. ## How was this patch tested? Manual tests and unit tests were added. Author: hyukjinkwon Closes #20727 from HyukjinKwon/linesep-text. --- python/pyspark/sql/readwriter.py | 14 ++++--- python/pyspark/sql/streaming.py | 8 +++- python/pyspark/sql/tests.py | 24 ++++++++++- .../apache/spark/sql/DataFrameReader.scala | 30 ++++++++------ .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/HadoopFileLinesReader.scala | 23 ++++++++++- .../datasources/text/TextFileFormat.scala | 16 ++++---- .../datasources/text/TextOptions.scala | 12 ++++++ .../sql/streaming/DataStreamReader.scala | 12 +++++- .../datasources/text/TextSuite.scala | 40 +++++++++++++++++++ 10 files changed, 147 insertions(+), 34 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index facc16bc53108..e5288636c596e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -304,16 +304,18 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths, wholetext=False): + def text(self, paths, wholetext=False, lineSep=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() @@ -322,7 +324,7 @@ def text(self, paths, wholetext=False): >>> df.collect() [Row(value=u'hello\\nthis')] """ - self._set_opts(wholetext=wholetext) + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -804,18 +806,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): self._jwrite.parquet(path) @since(1.6) - def text(self, path, compression=None): + def text(self, path, compression=None, lineSep=None): """Saves the content of the DataFrame in a text file at the specified path. :param path: the path in any Hadoop supported file system :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ - self._set_opts(compression=compression) + self._set_opts(compression=compression, lineSep=lineSep) self._jwrite.text(path) @since(2.0) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e8966c20a8f42..07f9ac1b5aa9e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -531,17 +531,20 @@ def parquet(self, path): @ignore_unicode_prefix @since(2.0) - def text(self, path): + def text(self, path, wholetext=False, lineSep=None): """ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. .. note:: Evolving. :param paths: string, or list of strings, for input path(s). + :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming @@ -549,6 +552,7 @@ def text(self, path): >>> "value" in str(text_sdf.schema) True """ + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.text(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 39d6c5226f138..967cc83166f3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -648,7 +648,29 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", multiLine=True) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0139913aaa4e2..1a5e47508c070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -647,14 +647,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * You can set the following text-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    - * By default, each line in the text files is a new row in the resulting DataFrame. - * - * Usage example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.read.text("/path/to/spark/README.md") @@ -663,6 +656,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().text("/path/to/spark/README.md") * }}} * + * You can set the following text-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input paths * @since 1.6.0 */ @@ -686,11 +687,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * You can set the following textFile-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: @@ -700,6 +696,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().textFile("/path/to/spark/README.md") * }}} * + * You can set the following textFile-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input path * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ed7a9100cc7f1..bb93889dc55e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -587,6 +587,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 83cf26c63a175..00a78f7343c59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -30,9 +30,22 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl /** * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. + * + * @param file A part (i.e. "block") of a single file that should be read line by line. + * @param lineSeparator A line separator that should be used for each line. If the value is `None`, + * it covers `\r`, `\r\n` and `\n`. + * @param conf Hadoop configuration + * + * @note The behavior when `lineSeparator` is `None` (covering `\r`, `\r\n` and `\n`) is defined + * by [[LineRecordReader]], not within Spark. */ class HadoopFileLinesReader( - file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { + file: PartitionedFile, + lineSeparator: Option[Array[Byte]], + conf: Configuration) extends Iterator[Text] with Closeable { + + def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf) + private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -42,7 +55,13 @@ class HadoopFileLinesReader( Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - val reader = new LineRecordReader() + + val reader = lineSeparator match { + case Some(sep) => new LineRecordReader(sep) + // If the line separator is `None`, it covers `\r`, `\r\n` and `\n`. + case _ => new LineRecordReader() + } + reader.initialize(fileSplit, hadoopAttemptContext) new RecordReaderIterator(reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index c661e9bd3b94c..9647f09867643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.execution.datasources.text -import java.io.Closeable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -89,7 +86,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -113,18 +110,18 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) + readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions) } private def readToUnsafeMem( conf: Broadcast[SerializableConfiguration], requiredSchema: StructType, - wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = { + textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = { (file: PartitionedFile) => { val confValue = conf.value.value - val reader = if (!wholeTextMode) { - new HadoopFileLinesReader(file, confValue) + val reader = if (!textOptions.wholeText) { + new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue) } else { new HadoopFileWholeTextReader(file, confValue) } @@ -152,6 +149,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { class TextOutputWriter( path: String, dataSchema: StructType, + lineSeparator: Array[Byte], context: TaskAttemptContext) extends OutputWriter { @@ -162,7 +160,7 @@ class TextOutputWriter( val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) } - writer.write('\n') + writer.write(lineSeparator) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 2a661561ab51e..18698df9fd8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -39,9 +41,19 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean + private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => + require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInWrite: Array[Byte] = + lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } private[text] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c393dcdfdd7e5..9b17406a816b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -387,7 +387,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * Each line in the text files is a new row in the resulting DataFrame. For example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.readStream.text("/path/to/directory/") @@ -400,6 +400,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @since 2.0.0 @@ -413,7 +417,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * Each line in the text file is a new element in the resulting Dataset. For example: + * By default, each line in the text file is a new element in the resulting Dataset. For example: * {{{ * // Scala: * spark.readStream.textFile("/path/to/spark/README.md") @@ -426,6 +430,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @param path input path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 33287044f279e..e8a5299d6ba9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.TestUtils import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -172,6 +175,43 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-23577: Support line separator - lineSep: '$lineSep'") { + // Read + val values = Seq("a", "b", "\nc") + val data = values.mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, Seq("a", "b", "\nc").toDF()) + } + } + + // Write + withTempPath { path => + values.toDF().coalesce(1) + .write.option("lineSep", lineSep).text(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = values.toDF() + df.write.option("lineSep", lineSep).text(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + testLineSeparator(lineSep) + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } From 98d0ea3f6091730285293321a50148f69e94c9cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 21 Mar 2018 09:52:28 -0700 Subject: [PATCH 193/593] [SPARK-23264][SQL] Fix scala.MatchError in literals.sql.out ## What changes were proposed in this pull request? To fix `scala.MatchError` in `literals.sql.out`, this pr added an entry for `CalendarIntervalType` in `QueryExecution.toHiveStructString`. ## How was this patch tested? Existing tests and added tests in `literals.sql` Author: Takeshi Yamamuro Closes #20872 from maropu/FixIntervalTests. --- .../spark/sql/execution/QueryExecution.scala | 2 ++ .../resources/sql-tests/inputs/literals.sql | 3 +++ .../sql-tests/results/literals.sql.out | 20 ++++++++++++------- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7cae24bf5976c..15379a0663f7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -155,6 +155,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -178,6 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes.contains(tpe) => other.toString } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 37b4b7606d12b..a743cf1ec2cde 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -105,3 +105,6 @@ select X'XuZ'; -- Hive literal_double test. SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; + +-- map + interval test +select map(1, interval 1 day, 2, interval 3 week); diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 95d4413148f64..b8c91dc8b59a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 43 +-- Number of queries: 44 -- !query 0 @@ -323,19 +323,17 @@ select timestamp '2016-33-11 20:54:00.000' -- !query 34 select interval 13.123456789 seconds, interval -13.123456789 second -- !query 34 schema -struct<> +struct -- !query 34 output -scala.MatchError -(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 13 seconds 123 milliseconds 456 microseconds interval -12 seconds -876 milliseconds -544 microseconds -- !query 35 select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond -- !query 35 schema -struct<> +struct -- !query 35 output -scala.MatchError -(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds 9 -- !query 36 @@ -416,3 +414,11 @@ SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> -- !query 42 output 3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 + + +-- !query 43 +select map(1, interval 1 day, 2, interval 3 week) +-- !query 43 schema +struct> +-- !query 43 output +{1:interval 1 days,2:interval 3 weeks} From 918c7e99afdcea05c36626e230636c4f8aabf82c Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 21 Mar 2018 10:06:26 -0700 Subject: [PATCH 194/593] [SPARK-23288][SS] Fix output metrics with parquet sink ## What changes were proposed in this pull request? Output metrics were not filled when parquet sink used. This PR fixes this problem by passing a `BasicWriteJobStatsTracker` in `FileStreamSink`. ## How was this patch tested? Additional unit test added. Author: Gabor Somogyi Closes #20745 from gaborgsomogyi/SPARK-23288. --- .../command/DataWritingCommand.scala | 11 +--- .../datasources/BasicWriteStatsTracker.scala | 25 +++++++-- .../execution/streaming/FileStreamSink.scala | 10 +++- .../sql/streaming/FileStreamSinkSuite.scala | 52 +++++++++++++++++++ 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e56f8105fc9a7..e11dbd201004d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} @@ -45,15 +44,7 @@ trait DataWritingCommand extends Command { // Output columns of the analyzed input query plan def outputColumns: Seq[Attribute] - lazy val metrics: Map[String, SQLMetric] = { - val sparkContext = SparkContext.getActive.get - Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") - ) - } + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 9dbbe9946ee99..69c03d862391e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker( totalNumOutput += summary.numRows } - metrics("numFiles").add(numFiles) - metrics("numOutputBytes").add(totalNumBytes) - metrics("numOutputRows").add(totalNumOutput) - metrics("numParts").add(numPartitions) + metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput) + metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) } } + +object BasicWriteJobStatsTracker { + private val NUM_FILES_KEY = "numFiles" + private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes" + private val NUM_OUTPUT_ROWS_KEY = "numOutputRows" + private val NUM_PARTS_KEY = "numParts" + + def metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"), + NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 87a17cebdc10c..b3d12f67b5d63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -97,6 +98,11 @@ class FileStreamSink( new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics) + } + override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") @@ -131,7 +137,7 @@ class FileStreamSink( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, - statsTrackers = Nil, + statsTrackers = Seq(basicWriteJobStatsTracker), options = options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 31e5527d7366a..cf41d7e0e4fe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.hadoop.fs.Path +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -405,4 +406,55 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("SPARK-23288 writing and checking output metrics") { + Seq("parquet", "orc", "text", "json").foreach { format => + val inputData = MemoryStream[String] + val df = inputData.toDF() + + withTempDir { outputDir => + withTempDir { checkpointDir => + + var query: StreamingQuery = null + + var numTasks = 0 + var recordsWritten: Long = 0L + var bytesWritten: Long = 0L + try { + spark.sparkContext.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val outputMetrics = taskEnd.taskMetrics.outputMetrics + recordsWritten += outputMetrics.recordsWritten + bytesWritten += outputMetrics.bytesWritten + numTasks += 1 + } + }) + + query = + df.writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format(format) + .start(outputDir.getCanonicalPath) + + inputData.addData("1", "2", "3") + inputData.addData("4", "5") + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + + assert(numTasks > 0) + assert(recordsWritten === 5) + // This is heavily file type/version specific but should be filled + assert(bytesWritten > 0) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + } } From 2b89e4aa2e8bd8b88f6e5eb60d95c1a58e5c4ace Mon Sep 17 00:00:00 2001 From: akonopko Date: Wed, 21 Mar 2018 14:40:21 -0500 Subject: [PATCH 195/593] [SPARK-18580][DSTREAM][KAFKA] Add spark.streaming.backpressure.initialRate to direct Kafka streams ## What changes were proposed in this pull request? Add `spark.streaming.backpressure.initialRate` to direct Kafka Streams for Kafka 0.8 and 0.10 This is required in order to be able to use backpressure with huge lags, which cannot be processed at once. Without this parameter `DirectKafkaInputDStream` with backpressure enabled would try to get all the possible data from Kafka before adjusting consumption rate ## How was this patch tested? - Tests added to `org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala` and `org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala` - Manual tests on YARN cluster Author: akonopko Author: Alexander Konopko Closes #19431 from akonopko/SPARK-18580-initialrate. --- .../kafka010/DirectKafkaInputDStream.scala | 8 ++- .../kafka010/DirectKafkaStreamSuite.scala | 51 +++++++++++++++- .../kafka/DirectKafkaInputDStream.scala | 9 ++- .../kafka/DirectKafkaStreamSuite.scala | 59 ++++++++++++++++++- 4 files changed, 120 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 9cb2448fea0f4..215b7cab703fb 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -56,6 +56,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + val executorKafkaParams = { val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams) KafkaUtils.fixKafkaParams(ekp) @@ -126,7 +129,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 8524743ee2846..35e4678f2e3c8 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010 import java.io.File import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap } +import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -34,7 +34,7 @@ import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} @@ -617,6 +617,53 @@ class DirectKafkaStreamSuite ssc.stop() } + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) + } + kafkaStream.start() + + val input = Map(new TopicPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> maxMessagesPerPartition)) // we run for half a second + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = getKafkaParams() diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d6dd0744441e4..9297c39d170c4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -91,9 +91,16 @@ class DirectKafkaInputDStream[ private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 3fea6cfd910bf..ecca38784e777 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import java.io.File -import java.util.Arrays +import java.util.{ Arrays, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -32,12 +32,11 @@ import kafka.serializer.StringDecoder import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils @@ -456,6 +455,60 @@ class DirectKafkaStreamSuite ssc.stop() } + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val topicPartitions = Set(TopicAndPartition(topic, 0)) + kafkaTestUtils.createTopic(topic, 1) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(topicPartitions) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) + } + kafkaStream.start() + + val input = Map(new TopicAndPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicAndPartition(topic, 0) -> maxMessagesPerPartition)) + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = Map( From a091ee676b8707819e94d92693956237310a6145 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 21 Mar 2018 13:52:03 -0700 Subject: [PATCH 196/593] [MINOR] Fix Java lint from new JavaKolmogorovSmirnovTestSuite ## What changes were proposed in this pull request? Fix lint-java from https://github.com/apache/spark/pull/19108 addition of JavaKolmogorovSmirnovTestSuite Author: Joseph K. Bradley Closes #20875 from jkbradley/kstest-lint-fix. --- .../spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java index 021272dd5a40c..830f668fe07b8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -18,18 +18,11 @@ package org.apache.spark.ml.stat; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.commons.math3.distribution.NormalDistribution; -import org.apache.spark.ml.linalg.VectorUDT; -import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.junit.Test; import org.apache.spark.SharedSparkSession; From 0604beaff2baa2d0fed86c0c87fd2a16a1838b5f Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Wed, 21 Mar 2018 17:05:39 -0700 Subject: [PATCH 197/593] [SPARK-23729][CORE] Respect URI fragment when resolving globs Firstly, glob resolution will not result in swallowing the remote name part (that is preceded by the `#` sign) in case of `--files` or `--archives` options Moreover in the special case of multiple resolutions when the remote naming does not make sense and error is returned. Enhanced current test and wrote additional test for the error case Author: Mihaly Toth Closes #20853 from misutoth/glob-with-remote-name. --- .../apache/spark/deploy/DependencyUtils.scala | 34 +++++++++++---- .../org/apache/spark/deploy/SparkSubmit.scala | 13 ++++++ .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++---- 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ecc82d7ac8001..ab319c860ee69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy import java.io.File +import java.net.URI import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.{MutableURLClassLoader, Utils} private[deploy] object DependencyUtils { @@ -137,16 +138,31 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") Utils.stringToSeq(paths).flatMap { path => - val uri = Utils.resolveURI(path) - uri.getScheme match { - case "local" | "http" | "https" | "ftp" => Array(path) - case _ => - val fs = FileSystem.get(uri, hadoopConf) - Option(fs.globStatus(new Path(uri))).map { status => - status.filter(_.isFile).map(_.getPath.toUri.toString) - }.getOrElse(Array(path)) + val (base, fragment) = splitOnFragment(path) + (resolveGlobPath(base, hadoopConf), fragment) match { + case (resolved, Some(_)) if resolved.length > 1 => throw new SparkException( + s"${base.toString} resolves ambiguously to multiple files: ${resolved.mkString(",")}") + case (resolved, Some(namedAs)) => resolved.map(_ + "#" + namedAs) + case (resolved, _) => resolved } }.mkString(",") } + private def splitOnFragment(path: String): (URI, Option[String]) = { + val uri = Utils.resolveURI(path) + val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) + (withoutFragment, Option(uri.getFragment)) + } + + private def resolveGlobPath(uri: URI, hadoopConf: Configuration): Array[String] = { + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(uri.toString) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(uri.toString)) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329bde08718fe..3965f17f4b56e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -245,6 +245,19 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { + try { + doPrepareSubmitEnvironment(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage) + throw e + } + } + + private def doPrepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d265643a80b4e..2d0c192db4915 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -606,10 +606,13 @@ class SparkSubmitSuite } test("resolves command line argument paths correctly") { - val jars = "/jar1,/jar2" // --jars - val files = "local:/file1,file2" // --files - val archives = "file:/archive1,archive2" // --archives - val pyFiles = "py-file1,py-file2" // --py-files + val dir = Utils.createTempDir() + val archive = Paths.get(dir.toPath.toString, "single.zip") + Files.createFile(archive) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" // Test jars and files val clArgs = Seq( @@ -636,9 +639,10 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) - appArgs2.archives should be (Utils.resolveURIs(archives)) + appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.archives") should fullyMatch regex + ("file:/archive1,file:.*#archive3") // Test python files val clArgs3 = Seq( @@ -657,6 +661,29 @@ class SparkSubmitSuite conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } + test("ambiguous archive mapping results in error message") { + val dir = Utils.createTempDir() + val archive1 = Paths.get(dir.toPath.toString, "first.zip") + val archive2 = Paths.get(dir.toPath.toString, "second.zip") + Files.createFile(archive1) + Files.createFile(archive2) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + + testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + } + test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files From 95e51ff849a4c46cae463636b1ee393042469e7b Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 21 Mar 2018 21:21:36 -0700 Subject: [PATCH 198/593] [SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should save/restore CSE state correctly ## What changes were proposed in this pull request? Fixed `CodegenContext.withSubExprEliminationExprs()` so that it saves/restores CSE state correctly. ## How was this patch tested? Added new unit test to verify that the old CSE state is indeed saved and restored around the `withSubExprEliminationExprs()` call. Manually verified that this test fails without this patch. Author: Kris Mok Closes #20870 from rednaxelafx/codegen-subexpr-fix. --- .../expressions/codegen/CodeGenerator.scala | 16 +++---- .../expressions/CodeGenerationSuite.scala | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fe5e63ec0a2bb..84b1e3fbda876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,7 +402,7 @@ class CodegenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -921,14 +921,12 @@ class CodegenContext { newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs.clear - newSubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = newSubExprEliminationExprs val genCodes = f // Restore previous subExprEliminationExprs - subExprEliminationExprs.clear - oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = oldsubExprEliminationExprs genCodes } @@ -942,7 +940,7 @@ class CodegenContext { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree) @@ -955,10 +953,10 @@ class CodegenContext { // Generate the code for this expression tree. val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) - e.foreach(subExprEliminationExprs.put(_, state)) + e.foreach(localSubExprEliminationExprs.put(_, state)) eval.code.trim } - SubExprCodes(codes, subExprEliminationExprs.toMap) + SubExprCodes(codes, localSubExprEliminationExprs.toMap) } /** @@ -1006,7 +1004,7 @@ class CodegenContext { subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) - e.foreach(subExprEliminationExprs.put(_, state)) + subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 64c13e8972036..398b6767654fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -442,4 +442,48 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(CodeGenerator.calculateParamLength( Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) } + + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + assert(ctx.subExprEliminationExprs.contains(ref)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + } + + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(add1)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + } + } } From 5c9eaa6b585e9febd782da8eb6490b24d0d39ff3 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 21 Mar 2018 21:49:02 -0700 Subject: [PATCH 199/593] [SPARK-23372][SQL] Writing empty struct in parquet fails during execution. It should fail earlier in the processing. ## What changes were proposed in this pull request? Currently we allow writing data frames with empty schema into a file based datasource for certain file formats such as JSON, ORC etc. For formats such as Parquet and Text, we raise error at different times of execution. For text format, we return error from the driver early on in processing where as for format such as parquet, the error is raised from executor. **Example** spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) **Results in** ``` SQL org.apache.parquet.schema.InvalidSchemaException: Cannot write a schema with an empty group: message spark_schema { } at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:27) at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:37) at org.apache.parquet.schema.MessageType.accept(MessageType.java:58) at org.apache.parquet.schema.TypeUtil.checkValidWriteSchema(TypeUtil.java:23) at org.apache.parquet.hadoop.ParquetFileWriter.(ParquetFileWriter.java:225) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:342) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:302) at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.(ParquetOutputWriter.scala:37) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:151) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.newOutputWriter(FileFormatWriter.scala:376) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.execute(FileFormatWriter.scala:387) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:278) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:276) at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1411) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:281) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:206) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:205) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread. ``` In this PR, we unify the error processing and raise error on attempt to write empty schema based dataframes into file based datasource (orc, parquet, text , csv, json etc) early on in the processing. ## How was this patch tested? Unit tests added in FileBasedDatasourceSuite. Author: Dilip Biswal Closes #20579 from dilipbiswal/spark-23372. --- docs/sql-programming-guide.md | 1 + .../execution/datasources/DataSource.scala | 26 ++++++++++++++++- .../spark/sql/FileBasedDataSourceSuite.scala | 28 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5b47fd77f2cbc..421e2eaf62bfb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1807,6 +1807,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 35fcff69b14d8..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils @@ -546,6 +546,7 @@ case class DataSource( case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => + DataSource.validateSchema(data.schema) planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") @@ -719,4 +720,27 @@ object DataSource extends Logging { } globPath } + + /** + * Called before writing into a FileFormat based data source to make sure the + * supplied schema is not empty. + * @param schema + */ + private def validateSchema(schema: StructType): Unit = { + def hasEmptySchema(schema: StructType): Boolean = { + schema.size == 0 || schema.find { + case StructField(_, b: StructType, _, _) => hasEmptySchema(b) + case _ => false + }.isDefined + } + + + if (hasEmptySchema(schema)) { + throw new AnalysisException( + s""" + |Datasource does not support writing empty or nested empty schemas. + |Please make sure the data schema has at least one or more column(s). + """.stripMargin) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index bd3071bcf9010..06303099f5310 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { @@ -107,6 +108,33 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + allFileBasedDataSources.foreach { format => + test(s"SPARK-23372 error while writing empty schema files using $format") { + withTempPath { outputPath => + val errMsg = intercept[AnalysisException] { + spark.emptyDataFrame.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + + // Nested empty schema + withTempPath { outputPath => + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", StructType(Nil)), + StructField("c", IntegerType) + )) + val df = spark.createDataFrame(sparkContext.emptyRDD[Row], schema) + val errMsg = intercept[AnalysisException] { + df.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => From 4d37008c78d7d6b8f8a649b375ecc090700eca4f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 19:57:32 +0100 Subject: [PATCH 200/593] [SPARK-23599][SQL] Use RandomUUIDGenerator in Uuid expression ## What changes were proposed in this pull request? As stated in Jira, there are problems with current `Uuid` expression which uses `java.util.UUID.randomUUID` for UUID generation. This patch uses the newly added `RandomUUIDGenerator` for UUID generation. So we can make `Uuid` deterministic between retries. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20861 from viirya/SPARK-23599-2. --- .../sql/catalyst/analysis/Analyzer.scala | 16 ++++ .../spark/sql/catalyst/expressions/misc.scala | 26 +++++-- .../ResolvedUuidExpressionsSuite.scala | 73 +++++++++++++++++++ .../expressions/ExpressionEvalHelper.scala | 5 +- .../expressions/MiscExpressionsSuite.scala | 19 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 6 files changed, 136 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7848f88bda1c9..e821e96522f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ @@ -177,6 +178,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: + ResolvedUuidExpressions :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -1994,6 +1996,20 @@ class Analyzer( } } + /** + * Set the seed for random number generation in Uuid expressions. + */ + object ResolvedUuidExpressions extends Rule[LogicalPlan] { + private lazy val random = new Random() + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case p if p.resolved => p + case p => p transformExpressionsUp { + case Uuid(None) => Uuid(Some(random.nextLong())) + } + } + } + /** * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the * null check. When user defines a UDF with primitive parameters, there is no way to tell if the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 38e4fe44b15ab..ec93620038cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -122,18 +123,33 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid() extends LeafExpression { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { - override lazy val deterministic: Boolean = false + def this() = this(None) + + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString) + @transient private[this] var randomGenerator: RandomUUIDGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = + randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex) + + override protected def evalInternal(input: InternalRow): Any = + randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = s"final UTF8String ${ev.value} = " + - s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + val randomGen = ctx.freshName("randomGen") + ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, + forceInline = true, + useFreshName = false) + ctx.addPartitionInitializationStatement(s"$randomGen = " + + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + + s"${randomSeed.get}L + partitionIndex);") + ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + isNull = "false") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala new file mode 100644 index 0000000000000..fe57c199b8744 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} + +/** + * Test suite for resolving Uuid expressions. + */ +class ResolvedUuidExpressionsSuite extends AnalysisTest { + + private lazy val a = 'a.int + private lazy val r = LocalRelation(a) + private lazy val uuid1 = Uuid().as('_uuid1) + private lazy val uuid2 = Uuid().as('_uuid2) + private lazy val uuid3 = Uuid().as('_uuid3) + private lazy val uuid1Ref = uuid1.toAttribute + + private val analyzer = getAnalyzer(caseSensitive = true) + + private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { + plan.flatMap { + case p => + p.expressions.flatMap(_.collect { + case u: Uuid => u + }) + } + } + + test("analyzed plan sets random seed for Uuid expression") { + val plan = r.select(a, uuid1) + val resolvedPlan = analyzer.executeAndCheck(plan) + getUuidExpressions(resolvedPlan).foreach { u => + assert(u.resolved) + assert(u.randomSeed.isDefined) + } + } + + test("Uuid expressions should have different random seeds") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan = analyzer.executeAndCheck(plan) + assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) + } + + test("Different analyzed plans should have different random seeds in Uuids") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan1 = analyzer.executeAndCheck(plan) + val resolvedPlan2 = analyzer.executeAndCheck(plan) + val uuids1 = getUuidExpressions(resolvedPlan1) + val uuids2 = getUuidExpressions(resolvedPlan2) + assert(uuids1.distinct.length == 3) + assert(uuids2.distinct.length == 3) + assert(uuids1.intersect(uuids2).length == 0) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c6343b1cbf600..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -176,7 +176,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithGeneratedMutableProjection( + protected def evaluateWithGeneratedMutableProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( @@ -220,7 +220,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithUnsafeProjection( + protected def evaluateWithUnsafeProjection( expression: Expression, inputRow: InternalRow = EmptyRow, factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { @@ -233,6 +233,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { Alias(expression, s"Optimized($expression)2")() :: Nil), expression) + plan.initialize(0) plan(inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index c3d08bf68c7bb..3383d421f5616 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.io.PrintStream +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -42,8 +44,21 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("uuid") { - checkEvaluation(Length(Uuid()), 36) - assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) + checkEvaluation(Length(Uuid(Some(0))), 36) + val r = new Random() + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) === + evaluateWithGeneratedMutableProjection(Uuid(seed1))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) === + evaluateWithUnsafeProjection(Uuid(seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !== + evaluateWithGeneratedMutableProjection(Uuid(seed2))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== + evaluateWithUnsafeProjection(Uuid(seed2))) } test("PrintToStderr") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8b66f77b2f923..f7b3393f65cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -2264,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(0, 10) :: Nil) assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + + test("Uuid expressions should produce same results at retries in the same DataFrame") { + val df = spark.range(1).select($"id", new Column(Uuid())) + checkAnswer(df, df.collect()) + } } From a649fcf32a7e610da2a2b4e3d94f5d1372c825d6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Mar 2018 21:20:41 -0700 Subject: [PATCH 201/593] [MINOR][PYTHON] Remove unused codes in schema parsing logics of PySpark ## What changes were proposed in this pull request? This PR proposes to remove out unused codes, `_ignore_brackets_split` and `_BRACKETS`. `_ignore_brackets_split` was introduced in https://github.com/apache/spark/commit/d57daf1f7732a7ac54a91fe112deeda0a254f9ef to refactor and support `toDF("...")`; however, https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb replaced the logics here. Seems `_ignore_brackets_split` is not referred anymore. `_BRACKETS` was introduced in https://github.com/apache/spark/commit/880eabec37c69ce4e9594d7babfac291b0f93f50; however, all other usages were removed out in https://github.com/apache/spark/commit/648a8626b82d27d84db3e48bccfd73d020828586. This is rather a followup for https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb which I missed in that PR. ## How was this patch tested? Manually tested. Existing tests should cover this. I also double checked by `grep` in the whole repo. Author: hyukjinkwon Closes #20878 from HyukjinKwon/minor-remove-unused. --- python/pyspark/sql/types.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 826aab97e58db..5d5919e451b46 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -752,41 +752,6 @@ def __eq__(self, other): _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _ignore_brackets_split(s, separator): - """ - Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. - given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return - ["a", "d"]. - """ - parts = [] - buf = "" - level = 0 - for c in s: - if c in _BRACKETS.keys(): - level += 1 - buf += c - elif c in _BRACKETS.values(): - if level == 0: - raise ValueError("Brackets are not correctly paired: %s" % s) - level -= 1 - buf += c - elif c == separator and level > 0: - buf += c - elif c == separator: - parts.append(buf) - buf = "" - else: - buf += c - - if len(buf) == 0: - raise ValueError("The %s cannot be the last char: %s" % (separator, s)) - parts.append(buf) - return parts - - def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals From b2edc30db1dcc6102687d20c158a2700965fdf51 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 21:23:25 -0700 Subject: [PATCH 202/593] [SPARK-23614][SQL] Fix incorrect reuse exchange when caching is used ## What changes were proposed in this pull request? We should provide customized canonicalize plan for `InMemoryRelation` and `InMemoryTableScanExec`. Otherwise, we can wrongly treat two different cached plans as same result. It causes wrongly reused exchange then. For a test query like this: ```scala val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() val group1 = cached.groupBy("x").agg(min(col("y")) as "value") val group2 = cached.groupBy("x").agg(min(col("z")) as "value") group1.union(group2) ``` Canonicalized plans before: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` You can find that they have the canonicalized plans are the same, although we use different columns in two `InMemoryTableScan`s. Canonicalized plan after: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#2] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20831 from viirya/SPARK-23614. --- .../execution/columnar/InMemoryRelation.scala | 10 ++++++++++ .../columnar/InMemoryTableScanExec.scala | 19 +++++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ .../spark/sql/execution/ExchangeSuite.scala | 7 +++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 22e16913d4da9..2579046e30708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan @@ -68,6 +69,15 @@ case class InMemoryRelation( override protected def innerChildren: Seq[SparkPlan] = Seq(child) + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), + storageLevel = StorageLevel.NONE, + child = child.canonicalized, + tableName = None)( + _cachedColumnBuffers, + sizeInBytesStats, + statsOfPlanToCache) + override def producedAttributes: AttributeSet = outputSet @transient val partitionStatistics = new PartitionStatistics(output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index a93e8a1ad954d..e73e1378d52e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -38,6 +38,11 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def doCanonicalize(): SparkPlan = + copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)), + predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)), + relation = relation.canonicalized.asInstanceOf[InMemoryRelation]) + override def vectorTypes: Option[Seq[String]] = Option(Seq.fill(attributes.length)( if (!conf.offHeapColumnVectorEnabled) { @@ -169,11 +174,13 @@ case class InMemoryTableScanExec( override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Keeps relation's partition statistics because we don't serialize relation. + private val stats = relation.partitionStatistics + private def statsFor(a: Attribute) = stats.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - @transient val buildFilter: PartialFunction[Expression, Expression] = { + @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -213,14 +220,14 @@ case class InMemoryTableScanExec( l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } - val partitionFilters: Seq[Expression] = { + lazy val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = filter.map( BindReferences.bindReference( _, - relation.partitionStatistics.schema, + stats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -243,7 +250,7 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. - val schema = relation.partitionStatistics.schema + val schema = stats.schema val schemaIndex = schema.zipWithIndex val buffers = relation.cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 49c59cf695dc1..9b745befcb611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1446,8 +1446,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val data = Seq(("a", null)) checkDataset(data.toDS(), data: _*) } + + test("SPARK-23614: Union produces incorrect results when caching is used") { + val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() + val group1 = cached.groupBy("x").agg(min(col("y")) as "value") + val group2 = cached.groupBy("x").agg(min(col("z")) as "value") + checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) + } } +case class TestDataUnion(x: Int, y: Int, z: Int) + case class SingleData(id: Int) case class DoubleData(id: Int, val1: String) case class TripleData(id: Int, val1: String, val2: Long) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 697d7e6520713..bde2de5b39fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -125,4 +125,11 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) } } + + test("SPARK-23614: Fix incorrect reuse exchange when caching is used") { + val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache() + val projection1 = cached.select("_1", "_2").queryExecution.executedPlan + val projection2 = cached.select("_1", "_3").queryExecution.executedPlan + assert(!projection1.sameResult(projection2)) + } } From 5fa438471110afbf4e2174df449ac79e292501f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 23 Mar 2018 13:59:21 +0800 Subject: [PATCH 203/593] [SPARK-23361][YARN] Allow AM to restart after initial tokens expire. Currently, the Spark AM relies on the initial set of tokens created by the submission client to be able to talk to HDFS and other services that require delegation tokens. This means that after those tokens expire, a new AM will fail to start (e.g. when there is an application failure and re-attempts are enabled). This PR makes it so that the first thing the AM does when the user provides a principal and keytab is to create new delegation tokens for use. This makes sure that the AM can be started irrespective of how old the original token set is. It also allows all of the token management to be done by the AM - there is no need for the submission client to set configuration values to tell the AM when to renew tokens. Note that even though in this case the AM will not be using the delegation tokens created by the submission client, those tokens still need to be provided to YARN, since they are used to do log aggregation. To be able to re-use the code in the AMCredentialRenewal for the above purposes, I refactored that class a bit so that it can fetch tokens into a pre-defined UGI, insted of always logging in. Another issue with re-attempts is that, after the fix that allows the AM to restart correctly, new executors would get confused about when to update credentials, because the credential updater used the update time initially set up by the submission code. This could make the executor fail to update credentials in time, since that value would be very out of date in the situation described in the bug. To fix that, I changed the YARN code to use the new RPC-based mechanism for distributing tokens to executors. This allowed the old credential updater code to be removed, and a lot of code in the renewer to be simplified. I also made two currently hardcoded values (the renewal time ratio, and the retry wait) configurable; while this probably never needs to be set by anyone in a production environment, it helps with testing; that's also why they're not documented. Tested on real cluster with a specially crafted application to test this functionality: checked proper access to HDFS, Hive and HBase in cluster mode with token renewal on and AM restarts. Tested things still work in client mode too. Author: Marcelo Vanzin Closes #20657 from vanzin/SPARK-23361. --- .../scala/org/apache/spark/SparkConf.scala | 12 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 32 +- .../CoarseGrainedExecutorBackend.scala | 12 - .../spark/internal/config/package.scala | 12 + .../MesosHadoopDelegationTokenManager.scala | 11 +- .../spark/deploy/yarn/ApplicationMaster.scala | 117 +++---- .../org/apache/spark/deploy/yarn/Client.scala | 102 ++---- .../deploy/yarn/YarnSparkHadoopUtil.scala | 20 -- .../org/apache/spark/deploy/yarn/config.scala | 25 -- .../yarn/security/AMCredentialRenewer.scala | 291 +++++++----------- .../yarn/security/CredentialUpdater.scala | 131 -------- .../YARNHadoopDelegationTokenManager.scala | 9 +- .../cluster/YarnClientSchedulerBackend.scala | 9 +- .../cluster/YarnSchedulerBackend.scala | 10 +- ...ARNHadoopDelegationTokenManagerSuite.scala | 7 +- .../apache/spark/streaming/Checkpoint.scala | 3 - 16 files changed, 238 insertions(+), 565 deletions(-) delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f53b2bed74c6e..129956e9f9ffa 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -603,13 +603,15 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.rpc", "2.0", "Not used anymore."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", "Please use the new blacklisting options, spark.blacklist.*"), - DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), - DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used anymore"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used anymore"), DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", - "Not used any more. Please use spark.shuffle.service.index.cache.size") + "Not used anymore. Please use spark.shuffle.service.index.cache.size"), + DeprecatedConfig("spark.yarn.credentials.file.retention.count", "2.4.0", "Not used anymore."), + DeprecatedConfig("spark.yarn.credentials.file.retention.days", "2.4.0", "Not used anymore.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -748,7 +750,7 @@ private[spark] object SparkConf extends Logging { } if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) { logWarning( - s"The configuration key $key is not supported any more " + + s"The configuration key $key is not supported anymore " + s"because Spark doesn't use Akka since 2.0") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 177295fb7af0f..8353e64a619cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -146,7 +147,8 @@ class SparkHadoopUtil extends Logging { private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) val creds = deserialize(tokens) - logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + logInfo("Updating delegation tokens for current user.") + logDebug(s"Adding/updating delegation tokens ${dumpTokens(creds)}") addCurrentUserCredentials(creds) } @@ -321,19 +323,6 @@ class SparkHadoopUtil extends Logging { } } - /** - * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. - * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. - */ - private[spark] def getConfBypassingFSCache( - hadoopConf: Configuration, - scheme: String): Configuration = { - val newConf = new Configuration(hadoopConf) - val confKey = s"fs.${scheme}.impl.disable.cache" - newConf.setBoolean(confKey, true) - newConf - } - /** * Dump the credentials' tokens to string values. * @@ -447,16 +436,17 @@ object SparkHadoopUtil { def get: SparkHadoopUtil = instance /** - * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date - * when a given fraction of the duration until the expiration date has passed. - * Formula: current time + (fraction * (time until expiration)) + * Given an expiration date for the current set of credentials, calculate the time when new + * credentials should be created. + * * @param expirationDate Drop-dead expiration date - * @param fraction fraction of the time until expiration return - * @return Date when the fraction of the time until expiration has passed + * @param conf Spark configuration + * @return Timestamp when new credentials should be created. */ - private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + private[spark] def nextCredentialRenewalTime(expirationDate: Long, conf: SparkConf): Long = { val ct = System.currentTimeMillis - (ct + (fraction * (expirationDate - ct))).toLong + val ratio = conf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) + (ct + (ratio * (expirationDate - ct))).toLong } /** diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9b62e4b1b7150..48d3630abd1f9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -213,13 +213,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } - if (driverConf.contains("spark.yarn.credentials.file")) { - logInfo("Will periodically update credentials from: " + - driverConf.get("spark.yarn.credentials.file")) - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("startCredentialUpdater", classOf[SparkConf]) - .invoke(null, driverConf) - } cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) @@ -234,11 +227,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() - if (driverConf.contains("spark.yarn.credentials.file")) { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("stopCredentialUpdater") - .invoke(null) - } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a313ad0554a3a..407545aa4a47a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -525,4 +525,16 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1g") + private[spark] val CREDENTIALS_RENEWAL_INTERVAL_RATIO = + ConfigBuilder("spark.security.credentials.renewalRatio") + .doc("Ratio of the credential's expiration time when Spark should fetch new credentials.") + .doubleConf + .createWithDefault(0.75d) + + private[spark] val CREDENTIALS_RENEWAL_RETRY_WAIT = + ConfigBuilder("spark.security.credentials.retryWait") + .doc("How long to wait before retrying to fetch new credentials after a failure.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1h") + } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 7165bfae18a5e..a1bf4f0c048fe 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils @@ -63,7 +64,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") @@ -104,8 +105,10 @@ private[spark] class MesosHadoopDelegationTokenManager( } catch { case e: Exception => // Log the error and try to write new tokens back in an hour - logWarning("Couldn't broadcast tokens, trying again in an hour", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning( + s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e) + credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS) return } scheduleRenewal(this) @@ -135,7 +138,7 @@ private[spark] class MesosHadoopDelegationTokenManager( "related configurations in the target services.") currTime } else { - SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf) } logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6e35d23def6f0..d04989e138f83 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -29,7 +29,6 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -41,7 +40,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} +import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -79,42 +78,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) - private val ugi = { - val original = UserGroupInformation.getCurrentUser() - - // If a principal and keytab were provided, log in to kerberos, and set up a thread to - // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL - // of the TGT, use a configuration to define how often to check that a relogin is necessary. - // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. - val principal = sparkConf.get(PRINCIPAL).orNull - val keytab = sparkConf.get(KEYTAB).orNull - if (principal != null && keytab != null) { - UserGroupInformation.loginUserFromKeytab(principal, keytab) - - val renewer = new Thread() { - override def run(): Unit = Utils.tryLogNonFatalError { - while (true) { - TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) - UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() - } - } + private val userClassLoader = { + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + + if (isClusterMode) { + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } - renewer.setName("am-kerberos-renewer") - renewer.setDaemon(true) - renewer.start() - - // Transfer the original user's tokens to the new user, since that's needed to connect to - // YARN. It also copies over any delegation tokens that might have been created by the - // client, which will then be transferred over when starting executors (until new ones - // are created by the periodic task). - val newUser = UserGroupInformation.getCurrentUser() - SparkHadoopUtil.get.transferCredentials(original, newUser) - newUser } else { - SparkHadoopUtil.get.createSparkUser() + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } } + private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ => + new AMCredentialRenewer(sparkConf, yarnConf) + } + + private val ugi = credentialRenewer match { + case Some(cr) => + // Set the context class loader so that the token renewer has access to jars distributed + // by the user. + val currentLoader = Thread.currentThread().getContextClassLoader() + Thread.currentThread().setContextClassLoader(userClassLoader) + try { + cr.start() + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + + case _ => + SparkHadoopUtil.get.createSparkUser() + } + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic @@ -148,23 +148,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // A flag to check whether user has initialized spark context @volatile private var registered = false - private val userClassLoader = { - val classpath = Client.getUserClasspath(sparkConf) - val urls = classpath.map { entry => - new URL("file:" + new File(entry.getPath()).getAbsolutePath()) - } - - if (isClusterMode) { - if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { - new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } - // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() @@ -189,8 +172,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // In cluster mode, used to tell the AM when the user's SparkContext has been initialized. private val sparkContextPromise = Promise[SparkContext]() - private var credentialRenewer: AMCredentialRenewer = _ - // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. @@ -316,31 +297,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - // If the credentials file config is present, we must periodically renew tokens. So create - // a new AMDelegationTokenRenewer - if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { - // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the - // classloader so that main jar and secondary jars could be used by AMCredentialRenewer. - val credentialRenewerThread = new Thread { - setName("AMCredentialRenewerStarter") - setContextClassLoader(userClassLoader) - - override def run(): Unit = { - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - yarnConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - - val credentialRenewer = - new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) - credentialRenewer.scheduleLoginFromKeytab() - } - } - - credentialRenewerThread.start() - credentialRenewerThread.join() - } - if (isClusterMode) { runDriver() } else { @@ -409,9 +365,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logDebug("shutting down user thread") userClassThread.interrupt() } - if (!inShutdown && credentialRenewer != null) { - credentialRenewer.stop() - credentialRenewer = null + if (!inShutdown) { + credentialRenewer.foreach(_.stop()) } } } @@ -468,6 +423,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends securityMgr, localResources) + credentialRenewer.foreach(_.setDriverRef(driverRef)) + // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), // the allocator is ready to service requests. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 28087dee831d1..5763c3dbc5a8a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -93,11 +93,21 @@ private[spark] class Client( private val distCacheMgr = new ClientDistributedCacheManager() - private var loginFromKeytab = false - private var principal: String = null - private var keytab: String = null - private var credentials: Credentials = null - private var amKeytabFileName: String = null + private val principal = sparkConf.get(PRINCIPAL).orNull + private val keytab = sparkConf.get(KEYTAB).orNull + private val loginFromKeytab = principal != null + private val amKeytabFileName: String = { + require((principal == null) == (keytab == null), + "Both principal and keytab must be defined, or neither.") + if (loginFromKeytab) { + logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + new File(keytab).getName() + "-" + UUID.randomUUID().toString + } else { + null + } + } private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -120,11 +130,6 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) } @@ -145,9 +150,6 @@ private[spark] class Client( var appId: ApplicationId = null try { launcherBackend.connect() - // Setup the credentials before doing anything else, - // so we have don't have issues at any point. - setupCredentials() yarnClient.init(hadoopConf) yarnClient.start() @@ -288,8 +290,26 @@ private[spark] class Client( appContext } - /** Set up security tokens for launching our ApplicationMaster container. */ + /** + * Set up security tokens for launching our ApplicationMaster container. + * + * This method will obtain delegation tokens from all the registered providers, and set them in + * the AM's launch context. + */ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + credentialManager.obtainDelegationTokens(hadoopConf, credentials) + + // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid + // that for regular users, since in those case the user already has access to the TGT, + // and adding delegation tokens could lead to expired or cancelled tokens being used + // later, as reported in SPARK-15754. + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) amContainer.setTokens(ByteBuffer.wrap(dob.getData)) @@ -384,36 +404,6 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = destDir.getFileSystem(hadoopConf) - // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) - - if (credentials != null) { - // Add credentials to current user's UGI, so that following operations don't need to use the - // Kerberos tgt to get delegations again in the client side. - val currentUser = UserGroupInformation.getCurrentUser() - if (SparkHadoopUtil.get.isProxyUser(currentUser)) { - currentUser.addCredentials(credentials) - } - logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) - } - - // If we use principal and keytab to login, also credentials can be renewed some time - // after current time, we should pass the next renewal and updating time to credential - // renewer and updater. - if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() && - nearestTimeOfNextRenewal != Long.MaxValue) { - - // Valid renewal time is 75% of next renewal time, and the valid update time will be - // slightly later then renewal time (80% of next renewal time). This is to make sure - // credentials are renewed and updated before expired. - val currTime = System.currentTimeMillis() - val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime - val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime - - sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong) - sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong) - } - // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. @@ -793,11 +783,6 @@ private[spark] class Client( populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (loginFromKeytab) { - val credentialsFile = "credentials-" + UUID.randomUUID().toString - sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) - logInfo(s"Credentials file set to: $credentialsFile") - } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -1014,25 +999,6 @@ private[spark] class Client( amContainer } - def setupCredentials(): Unit = { - loginFromKeytab = sparkConf.contains(PRINCIPAL.key) - if (loginFromKeytab) { - principal = sparkConf.get(PRINCIPAL).get - keytab = sparkConf.get(KEYTAB).orNull - - require(keytab != null, "Keytab must be specified when principal is specified.") - logInfo("Attempting to login to the Kerberos" + - s" using principal: $principal and keytab: $keytab") - val f = new File(keytab) - // Generate a file name that can be used for the keytab file, that does not conflict - // with any user file. - amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(PRINCIPAL.key, principal) - } - // Defensive copy of the credentials - credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials) - } - /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f406fabd61860..8eda6cb1277c5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.CredentialUpdater import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils @@ -38,8 +37,6 @@ import org.apache.spark.util.Utils object YarnSparkHadoopUtil { - private var credentialUpdater: CredentialUpdater = _ - // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. @@ -206,21 +203,4 @@ object YarnSparkHadoopUtil { filesystemsToAccess + stagingFS } - def startCredentialUpdater(sparkConf: SparkConf): Unit = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) - credentialUpdater.start() - } - - def stopCredentialUpdater(): Unit = { - if (credentialUpdater != null) { - credentialUpdater.stop() - credentialUpdater = null - } - } - } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 3ba3ae5ab4401..1a99b3bd57672 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -231,16 +231,6 @@ package object config { /* Security configuration. */ - private[spark] val CREDENTIAL_FILE_MAX_COUNT = - ConfigBuilder("spark.yarn.credentials.file.retention.count") - .intConf - .createWithDefault(5) - - private[spark] val CREDENTIALS_FILE_MAX_RETENTION = - ConfigBuilder("spark.yarn.credentials.file.retention.days") - .intConf - .createWithDefault(5) - private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + "fs.defaultFS does not need to be listed here.") @@ -271,11 +261,6 @@ package object config { /* Private configs. */ - private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") - .internal() - .stringConf - .createWithDefault(null) - // Internal config to propagate the location of the user's jar to the driver/executors private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") .internal() @@ -329,16 +314,6 @@ package object config { .stringConf .createOptional - private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - - private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1m") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index eaf2cff111a49..bc8d47dbd54c6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -18,221 +18,160 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils /** - * The following methods are primarily meant to make sure long-running apps like Spark - * Streaming apps can run without interruption while accessing secured services. The - * scheduleLoginFromKeytab method is called on the AM to get the new credentials. - * This method wakes up a thread that logs into the KDC - * once 75% of the renewal interval of the original credentials used for the container - * has elapsed. It then obtains new credentials and writes them to HDFS in a - * pre-specified location - the prefix of which is specified in the sparkConf by - * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc. - * - each update goes to a new file, with a monotonically increasing suffix), also the - * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater. - * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed. + * A manager tasked with periodically updating delegation tokens needed by the application. * - * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is - * called once 80% of the validity of the original credentials has elapsed. At that time the - * executor finds the credentials file with the latest timestamp and checks if it has read those - * credentials before (by keeping track of the suffix of the last file it read). If a new file has - * appeared, it will read the credentials and update the currently running UGI with it. This - * process happens again once 80% of the validity of this has expired. + * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run + * without interruption while accessing secured services. It periodically logs in to the KDC with + * user-provided credentials, and contacts all the configured secure services to obtain delegation + * tokens to be distributed to the rest of the application. + * + * This class will manage the kerberos login, by renewing the TGT when needed. Because the UGI API + * does not expose the TTL of the TGT, a configuration controls how often to check that a relogin is + * necessary. This is done reasonably often since the check is a no-op when the relogin is not yet + * needed. The check period can be overridden in the configuration. + * + * New delegation tokens are created once 75% of the renewal interval of the original tokens has + * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. + * The driver is tasked with distributing the tokens to other processes that might need them. */ private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { + hadoopConf: Configuration) extends Logging { - private var lastCredentialsFileSuffix = 0 + private val principal = sparkConf.get(PRINCIPAL).get + private val keytab = sparkConf.get(KEYTAB).get + private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - private val credentialRenewerThread: ScheduledExecutorService = + private val renewalExecutor: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - private val hadoopUtil = SparkHadoopUtil.get + private val driverRef = new AtomicReference[RpcEndpointRef]() - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) - private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT) - private val freshHadoopConf = - hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) + private val renewalTask = new Runnable() { + override def run(): Unit = { + updateTokensTask() + } + } - @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + def setDriverRef(ref: RpcEndpointRef): Unit = { + driverRef.set(ref) + } /** - * Schedule a login from the keytab and principal set using the --principal and --keytab - * arguments to spark-submit. This login happens only when the credentials of the current user - * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from - * SparkConf to do the login. This method is a no-op in non-YARN mode. + * Start the token renewer. Upon start, the renewer will: * + * - log in the configured user, and set up a task to keep that user's ticket renewed + * - obtain delegation tokens from all available providers + * - schedule a periodic task to update the tokens when needed. + * + * @return The newly logged in user. */ - private[spark] def scheduleLoginFromKeytab(): Unit = { - val principal = sparkConf.get(PRINCIPAL).get - val keytab = sparkConf.get(KEYTAB).get - - /** - * Schedule re-login and creation of new credentials. If credentials have already expired, this - * method will synchronously create new ones. - */ - def scheduleRenewal(runnable: Runnable): Unit = { - // Run now! - val remainingTime = timeOfNextRenewal - System.currentTimeMillis() - if (remainingTime <= 0) { - logInfo("Credentials have expired, creating new ones now.") - runnable.run() - } else { - logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + def start(): UserGroupInformation = { + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + val ugi = doLogin() + + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() } } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) - // This thread periodically runs on the AM to update the credentials on HDFS. - val credentialRenewerRunnable = - new Runnable { - override def run(): Unit = { - try { - writeNewCredentialsToHDFS(principal, keytab) - cleanupOldFiles() - } catch { - case e: Exception => - // Log the error and try to write new tokens back in an hour - logWarning("Failed to write out new credentials to HDFS, will try again in an " + - "hour! If this happens too often tasks will fail.", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) - return - } - scheduleRenewal(this) - } - } - // Schedule update of credentials. This handles the case of updating the credentials right now - // as well, since the renewal interval will be 0, and the thread will get scheduled - // immediately. - scheduleRenewal(credentialRenewerRunnable) + val creds = obtainTokensAndScheduleRenewal(ugi) + ugi.addCredentials(creds) + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. Explicitly avoid overwriting tokens that already exist in the current user's + // credentials, since those were freshly obtained above (see SPARK-23361). + val existing = ugi.getCredentials() + existing.mergeAll(originalCreds) + ugi.addCredentials(existing) + + ugi + } + + def stop(): Unit = { + renewalExecutor.shutdown() + } + + private def scheduleRenewal(delay: Long): Unit = { + val _delay = math.max(0, delay) + logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") + renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) } - // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At - // least numFilesToKeep files are kept for safety - private def cleanupOldFiles(): Unit = { - import scala.concurrent.duration._ + /** + * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself + * to fetch the next set of tokens when needed. + */ + private def updateTokensTask(): Unit = { try { - val remoteFs = FileSystem.get(freshHadoopConf) - val credentialsPath = new Path(credentialsFile) - val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .dropRight(numFilesToKeep) - .takeWhile(_.getModificationTime < thresholdTime) - .foreach(x => remoteFs.delete(x.getPath, true)) + val freshUGI = doLogin() + val creds = obtainTokensAndScheduleRenewal(freshUGI) + val tokens = SparkHadoopUtil.get.serialize(creds) + + val driver = driverRef.get() + if (driver != null) { + logInfo("Updating delegation tokens.") + driver.send(UpdateDelegationTokens(tokens)) + } else { + // This shouldn't really happen, since the driver should register way before tokens expire + // (or the AM should time out the application). + logWarning("Delegation tokens close to expiration but no driver has registered yet.") + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) + } } catch { - // Such errors are not fatal, so don't throw. Make sure they are logged though case e: Exception => - logWarning("Error while attempting to cleanup old credentials. If you are seeing many " + - "such warnings there may be an issue with your HDFS cluster.", e) + val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + + " If this happens too often tasks will fail.", e) + scheduleRenewal(delay) } } - private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = { - // Keytab is copied by YARN to the working directory of the AM, so full path is - // not needed. - - // HACK: - // HDFS will not issue new delegation tokens, if the Credentials object - // passed in already has tokens for that FS even if the tokens are expired (it really only - // checks if there are tokens for the service, and not if they are valid). So the only real - // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the current user's Credentials. - // So: - // - we login as a different user and get the UGI - // - use that UGI to get the tokens (see doAs block below) - // - copy the tokens over to the current user's credentials (this will overwrite the tokens - // in the current user's Credentials object for this FS). - // The login to KDC happens each time new tokens are required, but this is rare enough to not - // have to worry about (like once every day or so). This makes this code clearer than having - // to login and then relogin every time (the HDFS API may not relogin since we don't use this - // UGI directly for HDFS communication. - logInfo(s"Attempting to login to KDC using principal: $principal") - val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - logInfo("Successfully logged into KDC.") - val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(credentialsFile) - val dst = credentialsPath.getParent - var nearestNextRenewalTime = Long.MaxValue - keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { - // Get a copy of the credentials - override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainDelegationTokens( - freshHadoopConf, - tempCreds) - null + /** + * Obtain new delegation tokens from the available providers. Schedules a new task to fetch + * new tokens before the new set expires. + * + * @return Credentials containing the new tokens. + */ + private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { + ugi.doAs(new PrivilegedExceptionAction[Credentials]() { + override def run(): Credentials = { + val creds = new Credentials() + val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds) + + val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) - + System.currentTimeMillis() + scheduleRenewal(timeToWait) + creds } }) - - val currTime = System.currentTimeMillis() - val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) { - // If next renewal time is earlier than current time, we set next renewal time to current - // time, this will trigger next renewal immediately. Also set next update time to current - // time. There still has a gap between token renewal and update will potentially introduce - // issue. - logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " + - s"current time ($currTime), which is unexpected, please check your credential renewal " + - "related configurations in the target services.") - timeOfNextRenewal = currTime - currTime - } else { - // Next valid renewal time is about 75% of credential renewal time, and update time is - // slightly later than valid renewal time (80% of renewal time). - timeOfNextRenewal = - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) - } - - // Add the temp credentials back to the original ones. - UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(freshHadoopConf) - // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM - // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file - // and update the lastCredentialsFileSuffix. - if (lastCredentialsFileSuffix == 0) { - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.foreach { status => - lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) - } - } - val nextSuffix = lastCredentialsFileSuffix + 1 - - val tokenPathStr = - credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - nextSuffix - val tokenPath = new Path(tokenPathStr) - val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - - logInfo("Writing out delegation tokens to " + tempTokenPath.toString) - val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) - logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") - remoteFs.rename(tempTokenPath, tokenPath) - logInfo("Delegation token file rename complete.") - lastCredentialsFileSuffix = nextSuffix } - def stop(): Unit = { - credentialRenewerThread.shutdown() + private def doLogin(): UserGroupInformation = { + logInfo(s"Attempting to login to KDC using principal: $principal") + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + ugi } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala deleted file mode 100644 index fe173dffc22a8..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.concurrent.{Executors, TimeUnit} - -import scala.util.control.NonFatal - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -private[spark] class CredentialUpdater( - sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { - - @volatile private var lastCredentialsFileSuffix = 0 - - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val freshHadoopConf = - SparkHadoopUtil.get.getConfBypassingFSCache( - hadoopConf, new Path(credentialsFile).toUri.getScheme) - - private val credentialUpdater = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) - - // This thread wakes up and picks up new credentials from HDFS, if any. - private val credentialUpdaterRunnable = - new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) - } - - /** Start the credential updater task */ - def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) - val remainingTime = startTime - System.currentTimeMillis() - if (remainingTime <= 0) { - credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) - } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") - credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) - } - } - - private def updateCredentialsIfRequired(): Unit = { - val timeToNextUpdate = try { - val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(freshHadoopConf) - SparkHadoopUtil.get.listFilesSorted( - remoteFs, credentialsFilePath.getParent, - credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.map { credentialsStatus => - val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) - if (suffix > lastCredentialsFileSuffix) { - logInfo("Reading new credentials from " + credentialsStatus.getPath) - val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) - lastCredentialsFileSuffix = suffix - UserGroupInformation.getCurrentUser.addCredentials(newCredentials) - logInfo("Credentials updated from credentials file.") - - val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis()) - if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime - } else { - // If current credential file is older than expected, sleep 1 hour and check again. - TimeUnit.HOURS.toMillis(1) - } - }.getOrElse { - // Wait for 1 minute to check again if there's no credential file currently - TimeUnit.MINUTES.toMillis(1) - } - } catch { - // Since the file may get deleted while we are reading it, catch the Exception and come - // back in an hour to try again - case NonFatal(e) => - logWarning("Error while trying to update credentials, will try again in 1 hour", e) - TimeUnit.HOURS.toMillis(1) - } - - logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") - credentialUpdater.schedule( - credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) - } - - private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { - val stream = remoteFs.open(tokenPath) - try { - val newCredentials = new Credentials() - newCredentials.readTokenStorageStream(stream) - newCredentials - } finally { - stream.close() - } - } - - private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = { - val name = credentialsPath.getName - val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - val slice = name.substring(0, index) - val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - name.substring(last2index + 1, index).toLong - } - - def stop(): Unit = { - credentialUpdater.shutdown() - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 163cfb4eb8624..d4eeb6bbcf886 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -22,11 +22,11 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -37,11 +37,10 @@ import org.apache.spark.util.Utils */ private[yarn] class YARNHadoopDelegationTokenManager( sparkConf: SparkConf, - hadoopConf: Configuration, - fileSystems: Configuration => Set[FileSystem]) extends Logging { + hadoopConf: Configuration) extends Logging { - private val delegationTokenManager = - new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) // public for testing val credentialProviders = getCredentialProviders diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0c6206eebe41d..06e54a2eaf95a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -62,12 +62,6 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() - // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver - // reads the credentials from HDFS, just like the executors and updates its own credentials - // cache. - if (conf.contains("spark.yarn.credentials.file")) { - YarnSparkHadoopUtil.startCredentialUpdater(conf) - } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -153,7 +147,6 @@ private[spark] class YarnClientSchedulerBackend( client.reportLauncherState(SparkAppHandle.State.FINISHED) super.stop() - YarnSparkHadoopUtil.stopCredentialUpdater() client.stop() logInfo("Stopped") } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bb615c36cd97f..63bea3e7a5003 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,9 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -70,6 +72,7 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -263,8 +266,13 @@ private[spark] abstract class YarnSchedulerBackend( logWarning(s"Requesting driver to remove executor $executorId for reason $reason") driverEndpoint.send(r) } - } + case u @ UpdateDelegationTokens(tokens) => + // Add the tokens to the current user and send a message to the scheduler so that it + // notifies all registered executors of the new tokens. + SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf) + driverEndpoint.send(u) + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 3c7cdc0f1dab8..9fa749b14c98c 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.security.Credentials import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { private var credentialManager: YARNHadoopDelegationTokenManager = null @@ -36,11 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers } test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - + credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) credentialManager.credentialProviders.get("yarn-test") should not be (None) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..3703a87cdb9ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -57,9 +57,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", - "spark.yarn.credentials.file", - "spark.yarn.credentials.renewalTime", - "spark.yarn.credentials.updateTime", "spark.ui.filters", "spark.mesos.driver.frameworkId") From 92e952557dbd8a170d66d615e25c6c6a8399dd43 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Mar 2018 21:01:07 +0900 Subject: [PATCH 204/593] [MINOR][R] Fix R lint failure ## What changes were proposed in this pull request? The lint failure bugged me: ```R R/SQLContext.R:715:97: style: Trailing whitespace is superfluous. #' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to ^ tests/fulltests/test_streaming.R:239:45: style: Commas should always have a space after. expect_equal(times[order(times$eventTime),][1, 2], 2) ^ lintr checks failed. ``` and I actually saw https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.6-ubuntu-test/500/console too. If I understood correctly, there is a try about moving to Unbuntu one. ## How was this patch tested? Manually tested by `./dev/lint-r`: ``` ... lintr checks passed. ``` Author: hyukjinkwon Closes #20879 from HyukjinKwon/minor-r-lint. --- R/pkg/R/SQLContext.R | 2 +- R/pkg/tests/fulltests/test_streaming.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ebec0ce3d1920..429dd5d565492 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -712,7 +712,7 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to #' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it #' uses the default value, session local timezone. #' @return SparkDataFrame diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index a354d50c6b54e..bfb1a046490ec 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -236,7 +236,7 @@ test_that("Watermark", { times <- collect(sql("SELECT * FROM times")) # looks like write timing can affect the first bucket; but it should be t - expect_equal(times[order(times$eventTime),][1, 2], 2) + expect_equal(times[order(times$eventTime), ][1, 2], 2) stopQuery(q) unlink(parquetPath) From 6ac4fba69290e1c7de2c0a5863f224981dedb919 Mon Sep 17 00:00:00 2001 From: arucard21 Date: Fri, 23 Mar 2018 21:02:34 +0900 Subject: [PATCH 205/593] [SPARK-23769][CORE] Remove comments that unnecessarily disable Scalastyle check ## What changes were proposed in this pull request? We re-enabled the Scalastyle checker on a line of code. It was previously disabled, but it does not violate any of the rules. So there's no reason to disable the Scalastyle checker here. ## How was this patch tested? We tested this by running `build/mvn scalastyle:check` after removing the comments that disable the checker. This check passed with no errors or warnings for Spark Core ``` [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Core 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- scalastyle-maven-plugin:1.0.0:check (default-cli) spark-core_2.11 --- Saving to outputFile=/spark/core/target/scalastyle-output.xml Processed 485 file(s) Found 0 errors Found 0 warnings Found 0 infos ``` We did not run all tests (with `dev/run-tests`) since this Scalastyle check seemed sufficient. ## Co-contributors: chialun-yeh Hrayo712 vpourquie Author: arucard21 Closes #20880 from arucard21/scalastyle_util. --- .../org/apache/spark/storage/BlockReplicationPolicy.scala | 4 +--- .../main/scala/org/apache/spark/util/CompletionIterator.scala | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index 353eac60df171..0bacc34cdfd90 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -54,10 +54,9 @@ trait BlockReplicationPolicy { } object BlockReplicationUtils { - // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see + * minimizing space usage. Please see * here. * * @param n total number of indices @@ -65,7 +64,6 @@ object BlockReplicationUtils { * @param r random number generator * @return list of m random unique indices */ - // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 31d230d0fec8e..21acaa95c5645 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -22,9 +22,7 @@ package org.apache.spark.util * through all the elements. */ private[spark] -// scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { -// scalastyle:on private[this] var completed = false def next(): A = sub.next() From 8b56f16640fc4156aa7bd529c54469d27635b951 Mon Sep 17 00:00:00 2001 From: bag_of_tricks Date: Fri, 23 Mar 2018 10:36:23 -0700 Subject: [PATCH 206/593] [SPARK-23759][UI] Unable to bind Spark UI to specific host name / IP ## What changes were proposed in this pull request? Fixes SPARK-23759 by moving connector.start() after connector.setHost() Problem was created due connector.setHost(hostName) call was after connector.start() ## How was this patch tested? Patch was tested after build and deployment. This patch requires SPARK_LOCAL_IP environment variable to be set on spark-env.sh Author: bag_of_tricks Closes #20883 from felixalbani/SPARK-23759. --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0adeb4058b6e4..0e8a6307de6a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -343,12 +343,13 @@ private[spark] object JettyUtils extends Logging { -1, connectionFactories: _*) connector.setPort(port) - connector.start() + connector.setHost(hostName) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) - connector.setHost(hostName) + + connector.start() // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 From cb43bbe13606673349511829fd71d1f34fc39c45 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 23 Mar 2018 11:42:40 -0700 Subject: [PATCH 207/593] [SPARK-21685][PYTHON][ML] PySpark Params isSet state should not change after transform ## What changes were proposed in this pull request? Currently when a PySpark Model is transformed, default params that have not been explicitly set are then set on the Java side on the call to `wrapper._transfer_values_to_java`. This incorrectly changes the state of the Param as it should still be marked as a default value only. ## How was this patch tested? Added a new test to verify that when transferring Params to Java, default params have their state preserved. Author: Bryan Cutler Closes #18982 from BryanCutler/pyspark-ml-param-to-java-defaults-SPARK-21685. --- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- python/pyspark/ml/wrapper.py | 13 ++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index fd45fd00b270b..080119959a4e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -369,7 +369,7 @@ def test_property(self): raise RuntimeError("Test property to raise error when invoked") -class ParamTests(PySparkTestCase): +class ParamTests(SparkSessionTestCase): def test_copy_new_parent(self): testParams = TestParams() @@ -514,6 +514,24 @@ def test_logistic_regression_check_thresholds(self): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + @staticmethod def check_params(test_self, py_stage, check_params_exist=True): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 5061f6434794a..d325633195ddb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -118,11 +118,18 @@ def _transfer_params_to_java(self): """ Transforms the embedded params to the companion Java object. """ - paramMap = self.extractParamMap() + pair_defaults = [] for param in self.params: - if param in paramMap: - pair = self._make_java_param_pair(param, paramMap[param]) + if self.isSet(param): + pair = self._make_java_param_pair(param, self._paramMap[param]) self._java_obj.set(pair) + if self.hasDefault(param): + pair = self._make_java_param_pair(param, self._defaultParamMap[param]) + pair_defaults.append(pair) + if len(pair_defaults) > 0: + sc = SparkContext._active_spark_context + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) + self._java_obj.setDefault(pair_defaults_seq) def _transfer_param_map_to_java(self, pyParamMap): """ From 95c03cbd27cea2255d9d748f9a84a0a38e54594d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Mar 2018 11:56:17 -0700 Subject: [PATCH 208/593] [SPARK-23783][SPARK-11239][ML] Add PMML export to Spark ML pipelines ## What changes were proposed in this pull request? Adds PMML export support to Spark ML pipelines in the style of Spark's DataSource API to allow library authors to add their own model export formats. Includes a specific implementation for Spark ML linear regression PMML export. In addition to adding PMML to reach parity with our current MLlib implementation, this approach will allow other libraries & formats (like PFA) to implement and export models with a unified API. ## How was this patch tested? Basic unit test. Author: Holden Karau Author: Holden Karau Closes #19876 from holdenk/SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans-r2. --- .../org.apache.spark.ml.util.MLFormatRegister | 2 + .../ml/regression/LinearRegression.scala | 70 ++++--- .../org/apache/spark/ml/util/ReadWrite.scala | 173 +++++++++++++++++- .../org.apache.spark.ml.util.MLFormatRegister | 3 + .../ml/regression/LinearRegressionSuite.scala | 27 ++- .../spark/ml/util/PMMLReadWriteTest.scala | 55 ++++++ .../org/apache/spark/ml/util/PMMLUtils.scala | 43 +++++ .../apache/spark/ml/util/ReadWriteSuite.scala | 132 +++++++++++++ 8 files changed, 474 insertions(+), 31 deletions(-) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..5e5484fd8784d --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,2 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92510154d500e..f67d9d831f327 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{PipelineStage, PredictorParams} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ @@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with MLWritable { + with LinearRegressionParams with GeneralMLWritable { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) @@ -710,7 +711,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -718,7 +719,50 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) +} + +/** A writer for LinearRegression that handles the "internal" (or default) format */ +private class InternalLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector, scale: Double) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[LinearRegressionModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients, scale + val data = Data(instance.intercept, instance.coefficients, instance.scale) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for LinearRegression that handles the "pmml" format */ +private class PMMLLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val sc = sparkSession.sparkContext + // Construct the MLLib model which knows how to write to PMML. + val instance = stage.asInstanceOf[LinearRegressionModel] + val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) + // Save PMML + oldModel.toPMML(sc, path) + } } @Since("1.6.0") @@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") override def load(path: String): LinearRegressionModel = super.load(path) - /** [[MLWriter]] instance for [[LinearRegressionModel]] */ - private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends MLWriter with Logging { - - private case class Data(intercept: Double, coefficients: Vector, scale: Double) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients, scale - val data = Data(instance.intercept, instance.coefficients, instance.scale) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a616907800969..7edcd498678cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,9 +18,11 @@ package org.apache.spark.ml.util import java.io.IOException -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.json4s._ @@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Abstract class to be implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.4.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLFormatRegister extends MLWriterFormat { + /** + * The string that represents the format that this format provider uses. This is, along with + * stageName, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def format(): String = + * "pmml" + * }}} + * Indicates that this format is capable of saving a pmml model. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def format(): String + + /** + * The string that represents the stage type that this writer supports. This is, along with + * format, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def stageName(): String = + * "org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own PMML model. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def stageName(): String + + private[ml] def shortName(): String = s"${format()}+${stageName()}" +} + +/** + * Abstract class for utility classes that can save ML instances in Spark's internal format. */ @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { @@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + /** * Map to store extra options for this writer. */ @@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + @Since("1.6.0") + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + @Since("1.6.0") + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +@InterfaceStability.Unstable +@Since("2.4.0") +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. "pmml", "internal", or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { - shouldOverwrite = true + @Since("2.4.0") + def format(source: String): this.type = { + this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.4.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String): Unit = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) + val stageName = stage.getClass.getName + val targetName = s"$source+$stageName" + val formats = serviceLoader.asScala.toList + val shortNames = formats.map(_.shortName()) + val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => + Try(loader.loadClass(source)) match { + case Success(writer) => + // Found the ML writer using the fully qualified path + writer + case Failure(error) => + throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) + } + case head :: Nil => + head.getClass + case _ => + // Multiple sources + throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") + } + if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + writer.write(path, sparkSession, optionMap, stage) + } else { + throw new SparkException(s"ML source $source is not a valid MLWriterFormat") + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) @@ -162,6 +306,19 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * Trait for classes that provide `GeneralMLWriter`. + */ +@Since("2.4.0") +@InterfaceStability.Unstable +trait GeneralMLWritable extends MLWritable { + /** + * Returns an `MLWriter` instance for this ML instance. + */ + @Since("2.4.0") + override def write: GeneralMLWriter +} + /** * :: DeveloperApi :: * diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..100ef2545418f --- /dev/null +++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,3 @@ +org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 +org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 +org.apache.spark.ml.util.FakeLinearRegressionWriterWithName diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 9b19f63eba1bd..90ceb7dee38f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.ml.regression +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random +import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} + import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { + +class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { import testImplicits._ @@ -1052,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 0000000000000..d2c4832b12bac --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. + */ + def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.write.format("pmml").save(path) + intercept[IOException] { + instance.write.format("pmml").save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark")) + checkModelData(pmmlModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala new file mode 100644 index 0000000000000..dbdc69f95d841 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.util + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala new file mode 100644 index 0000000000000..f4c1f0bdb32cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.sql.{DataFrame, SparkSession} + +class FakeLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + +class FakeLinearRegressionWriterWithName extends MLFormatRegister { + override def format(): String = "fakeWithName" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + + +class DuplicateLinearRegressionWriter1 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class DuplicateLinearRegressionWriter2 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class ReadWriteSuite extends MLTest { + + import testImplicits._ + + private val seed: Int = 42 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0), + xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF() + } + + test("unsupported/non existent export formats") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + // Does not exist with a long class name + val thrownDNE = intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + assert(thrownDNE.getMessage(). + contains("Could not load requested format")) + + // Does not exist with a short name + val thrownDNEShort = intercept[SparkException] { + model.write.format("boop").save("boop") + } + assert(thrownDNEShort.getMessage(). + contains("Could not load requested format")) + + // Check with a valid class that is not a writer format. + val thrownInvalid = intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + assert(thrownInvalid.getMessage() + .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat")) + } + + test("invalid paths fail") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("pmml").save("") + } + assert(thrown.getMessage().contains("Can not create a Path from an empty string")) + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name") + } + assert(thrown.getMessage().contains("Fake writer doesn't write")) + val thrownWithName = intercept[Exception] { + model.write.format("fakeWithName").save("name") + } + assert(thrownWithName.getMessage().contains("Fake writer doesn't write")) + } + + test("duplicate format raises error") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("dupe").save("dupepanda") + } + assert(thrown.getMessage().contains("Multiple writers found for")) + } +} From a33655348c4066d9c1d8ad2055aadfbc892ba7fd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 23 Mar 2018 15:58:48 -0700 Subject: [PATCH 209/593] [SPARK-23615][ML][PYSPARK] Add maxDF Parameter to Python CountVectorizer ## What changes were proposed in this pull request? The maxDF parameter is for filtering out frequently occurring terms. This param was recently added to the Scala CountVectorizer and needs to be added to Python also. ## How was this patch tested? add test Author: Huaxin Gao Closes #20777 from huaxingao/spark-23615. --- .../spark/ml/feature/CountVectorizer.scala | 20 +++++----- python/pyspark/ml/feature.py | 40 ++++++++++++++----- python/pyspark/ml/tests.py | 25 ++++++++++++ 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 60a4f918790a3..9e0ed437e7bfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,19 +70,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit def getMinDF: Double = $(minDF) /** - * Specifies the maximum number of different documents a term must appear in to be included - * in the vocabulary. - * If this is an integer greater than or equal to 1, this specifies the number of documents - * the term must appear in; if this is a double in [0,1), then this specifies the fraction of - * documents. + * Specifies the maximum number of different documents a term could appear in to be included + * in the vocabulary. A term that appears more than the threshold will be ignored. If this is an + * integer greater than or equal to 1, this specifies the maximum number of documents the term + * could appear in; if this is a double in [0,1), then this specifies the maximum fraction of + * documents the term could appear in. * - * Default: (2^64^) - 1 + * Default: (2^63^) - 1 * @group param */ val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an integer >= 1," + + " this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum fraction of" + + " documents the term could appear in.", ParamValidators.gtEq(0.0)) /** @group getParam */ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a1ceb7f02da8b..fcb0dfc563720 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -422,6 +422,14 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): " If this is an integer >= 1, this specifies the number of documents the term must" + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + " Default 1.0", typeConverter=TypeConverters.toFloat) + maxDF = Param( + Params._dummy(), "maxDF", "Specifies the maximum number of" + + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an" + + " integer >= 1, this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum" + + " fraction of documents the term could appear in." + + " Default (2^63) - 1", typeConverter=TypeConverters.toFloat) vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) @@ -433,7 +441,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): def __init__(self, *args): super(_CountVectorizerParams, self).__init__(*args) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False) @since("1.6.0") def getMinTF(self): @@ -449,6 +457,13 @@ def getMinDF(self): """ return self.getOrDefault(self.minDF) + @since("2.4.0") + def getMaxDF(self): + """ + Gets the value of maxDF or its default value. + """ + return self.getOrDefault(self.maxDF) + @since("1.6.0") def getVocabSize(self): """ @@ -513,11 +528,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav """ @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None,outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", @@ -527,11 +542,11 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None, outputCol=None) Set the params for the CountVectorizer """ kwargs = self._input_kwargs @@ -551,6 +566,13 @@ def setMinDF(self, value): """ return self._set(minDF=value) + @since("2.4.0") + def setMaxDF(self, value): + """ + Sets the value of :py:attr:`maxDF`. + """ + return self._set(maxDF=value) + @since("1.6.0") def setVocabSize(self, value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 080119959a4e8..cf1ffa181ecec 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -697,6 +697,31 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", outputCol="features", minTF=2) From 816a5496ba4caac438f70400f72bb10bfcc02418 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Sat, 24 Mar 2018 18:21:01 -0700 Subject: [PATCH 210/593] [SPARK-23788][SS] Fix race in StreamingQuerySuite ## What changes were proposed in this pull request? The serializability test uses the same MemoryStream instance for 3 different queries. If any of those queries ask it to commit before the others have run, the rest will see empty dataframes. This can fail the test if q3 is affected. We should use one instance per query instead. ## How was this patch tested? Existing unit test. If I move q2.processAllAvailable() before starting q3, the test always fails without the fix. Author: Jose Torres Closes #20896 from jose-torres/fixrace. --- .../spark/sql/streaming/StreamingQuerySuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ebc9a87b23f84..08749b49997e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -550,22 +550,22 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi .start() } - val input = MemoryStream[Int] - val q1 = startQuery(input.toDS, "stream_serializable_test_1") - val q2 = startQuery(input.toDS.map { i => + val input = MemoryStream[Int] :: MemoryStream[Int] :: MemoryStream[Int] :: Nil + val q1 = startQuery(input(0).toDS, "stream_serializable_test_1") + val q2 = startQuery(input(1).toDS.map { i => // Emulate that `StreamingQuery` get captured with normal usage unintentionally. // It should not fail the query. q1 i }, "stream_serializable_test_2") - val q3 = startQuery(input.toDS.map { i => + val q3 = startQuery(input(2).toDS.map { i => // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear // error message. q1.explain() i }, "stream_serializable_test_3") try { - input.addData(1) + input.foreach(_.addData(1)) // q2 should not fail since it doesn't use `q1` in the closure q2.processAllAvailable() From 5f653d4f7c84e6147cd323cd650da65e0381ebe8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 25 Mar 2018 09:18:26 -0700 Subject: [PATCH 211/593] [SPARK-23167][SQL] Add TPCDS queries v2.7 in TPCDSQuerySuite ## What changes were proposed in this pull request? This pr added TPCDS v2.7 (latest) queries in `TPCDSQuerySuite` because the current `TPCDSQuerySuite` tests older one (v1.4) and some queries are different from v1.4 and v2.7. Since the original v2.7 queries have the syntaxes that Spark cannot parse, I changed these queries in a following way: - [date] + 14 days -> date + `INTERVAL` 14 days - [column name] as "30 days" -> [column name] as \`30 days\` - Fix some syntax errors, e.g., missing brackets ## How was this patch tested? Added tests in `TPCDSQuerySuite`. Author: Takeshi Yamamuro Closes #20343 from maropu/TPCDSV2_7. --- .../src/test/resources/tpcds-v2.7.0/q10a.sql | 69 ++++++ .../src/test/resources/tpcds-v2.7.0/q11.sql | 84 +++++++ .../src/test/resources/tpcds-v2.7.0/q12.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q14.sql | 135 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q14a.sql | 215 ++++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q18a.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q20.sql | 19 ++ .../src/test/resources/tpcds-v2.7.0/q22.sql | 15 ++ .../src/test/resources/tpcds-v2.7.0/q22a.sql | 94 ++++++++ .../src/test/resources/tpcds-v2.7.0/q24.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q27a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q34.sql | 37 +++ .../src/test/resources/tpcds-v2.7.0/q35.sql | 65 ++++++ .../src/test/resources/tpcds-v2.7.0/q35a.sql | 62 +++++ .../src/test/resources/tpcds-v2.7.0/q36a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q47.sql | 64 ++++++ .../src/test/resources/tpcds-v2.7.0/q49.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q51a.sql | 103 +++++++++ .../src/test/resources/tpcds-v2.7.0/q57.sql | 57 +++++ .../src/test/resources/tpcds-v2.7.0/q5a.sql | 158 +++++++++++++ .../src/test/resources/tpcds-v2.7.0/q6.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q64.sql | 111 +++++++++ .../src/test/resources/tpcds-v2.7.0/q67a.sql | 208 +++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q70a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q72.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q74.sql | 60 +++++ .../src/test/resources/tpcds-v2.7.0/q75.sql | 78 +++++++ .../src/test/resources/tpcds-v2.7.0/q77a.sql | 121 ++++++++++ .../src/test/resources/tpcds-v2.7.0/q78.sql | 75 ++++++ .../src/test/resources/tpcds-v2.7.0/q80a.sql | 147 ++++++++++++ .../src/test/resources/tpcds-v2.7.0/q86a.sql | 61 +++++ .../src/test/resources/tpcds-v2.7.0/q98.sql | 22 ++ .../apache/spark/sql/TPCDSQuerySuite.scala | 38 +++- 33 files changed, 2691 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q11.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q12.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q20.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q22.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q24.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q34.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q35.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q47.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q49.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q57.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q6.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q64.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q72.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q74.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q75.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q78.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q98.sql diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql new file mode 100644 index 0000000000000..50e521567eb3a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql @@ -0,0 +1,69 @@ +-- This is a new query in TPCDS v2.7 +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from + customer c,customer_address ca,customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and ca_county in ('Walker County', 'Richland County', 'Gaines County', 'Douglas County', 'Dona Ana County') + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales,date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) + and exists ( + select * + from ( + select + ws_bill_customer_sk as customer_sk, + d_year, + d_moy + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3 + union all + select + cs_ship_customer_sk as customer_sk, + d_year, + d_moy + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) x + where c.c_customer_sk = customer_sk) +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql new file mode 100755 index 0000000000000..97bed33721742 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql @@ -0,0 +1,84 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, c_email_address, d_year) +SELECT + -- select list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END +ORDER BY + -- order-by list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql new file mode 100755 index 0000000000000..7a6fafd22428a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql @@ -0,0 +1,23 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql new file mode 100644 index 0000000000000..b2ca3ddaf2baf --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql @@ -0,0 +1,135 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14a.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1998 AND 1998 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1998 AND 1998 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1998 AND 1998 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x) +select + * +from ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + 1 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year, + ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales)) last_year +where + this_year.i_brand_id = last_year.i_brand_id + and this_year.i_class_id = last_year.i_class_id + and this_year.i_category_id = last_year.i_category_id +order by + this_year.channel, + this_year.i_brand_id, + this_year.i_class_id, + this_year.i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql new file mode 100644 index 0000000000000..bfa70fe62d8d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql @@ -0,0 +1,215 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14b.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1999 AND 1999 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1999 AND 1999 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1999 AND 1999 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1999 and 2001 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x), +results AS ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales) sum_sales, + sum(number_sales) number_sales + from ( + select + 'store' channel, + i_brand_id,i_class_id, + i_category_id, + sum(ss_quantity*ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales) + union all + select + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity*cs_list_price) sales, + count(*) number_sales + from + catalog_sales, item, date_dim + where + cs_item_sk in (select ss_item_sk from cross_items) + and cs_item_sk = i_item_sk + and cs_sold_date_sk = d_date_sk + and d_year = 1998+2 + and d_moy = 11 + group by + i_brand_id,i_class_id,i_category_id + having + sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales) + union all + select + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity*ws_list_price) sales, + count(*) number_sales + from + web_sales, item, date_dim + where + ws_item_sk in (select ss_item_sk from cross_items) + and ws_item_sk = i_item_sk + and ws_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales)) y + group by + channel, + i_brand_id, + i_class_id, + i_category_id) +select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales +from ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales + from + results + union + select + channel, + i_brand_id, + i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id, + i_class_id + union + select + channel, + i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id + union + select + channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel + union + select + null as channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results) z +order by + channel, + i_brand_id, + i_class_id, + i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql new file mode 100644 index 0000000000000..2201a302ab352 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql @@ -0,0 +1,133 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + cast(cs_quantity as decimal(12,2)) agg1, + cast(cs_list_price as decimal(12,2)) agg2, + cast(cs_coupon_amt as decimal(12,2)) agg3, + cast(cs_sales_price as decimal(12,2)) agg4, + cast(cs_net_profit as decimal(12,2)) agg5, + cast(c_birth_year as decimal(12,2)) agg6, + cast(cd1.cd_dep_count as decimal(12,2)) agg7 + from + catalog_sales, customer_demographics cd1, customer_demographics cd2, customer, + customer_address, date_dim, item + where + cs_sold_date_sk = d_date_sk + and cs_item_sk = i_item_sk + and cs_bill_cdemo_sk = cd1.cd_demo_sk + and cs_bill_customer_sk = c_customer_sk + and cd1.cd_gender = 'M' + and cd1.cd_education_status = 'College' + and c_current_cdemo_sk = cd2.cd_demo_sk + and c_current_addr_sk = ca_address_sk + and c_birth_month in (9,5,12,4,1,10) + and d_year = 2001 + and ca_state in ('ND','WI','AL','NC','OK','MS','TN')) +select + i_item_id, + ca_country, + ca_state, + ca_county, + agg1, + agg2, + agg3, + agg4, + agg5, + agg6, + agg7 +from ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state, + ca_county + union all + select + i_item_id, + ca_country, + ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state + union all + select + i_item_id, + ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id, + ca_country + union all + select + i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results) foo +order by + ca_country, + ca_state, + ca_county, + i_item_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql new file mode 100755 index 0000000000000..34d46b1394d8f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql new file mode 100755 index 0000000000000..e7bea0804f162 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql @@ -0,0 +1,15 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + -- q22 in TPCDS v1.4 had a condition below: + -- AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql new file mode 100644 index 0000000000000..c886e6271511b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql @@ -0,0 +1,94 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh + from + inventory, date_dim, item, warehouse + where + inv_date_sk = d_date_sk + and inv_item_sk = i_item_sk + and inv_warehouse_sk = w_warehouse_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_product_name, + i_brand, + i_class, + i_category), +results_rollup as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class, + i_category + union all + select + i_product_name, + i_brand, + i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class + union all + select + i_product_name, + i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand + union all + select + i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name + union all + select + null i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results) +select + i_product_name, + i_brand, + i_class, + i_category, + qoh +from + results_rollup +order by + qoh, + i_product_name, + i_brand, + i_class, + i_category +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql new file mode 100755 index 0000000000000..92d64bc7eba78 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql @@ -0,0 +1,40 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_current_addr_sk = ca_address_sk -- This condition did not exist in TPCDS v1.4 + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) +-- no order-by exists in q24a of TPCDS v1.4 +ORDER BY + c_last_name, + c_first_name, + s_store_name diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql new file mode 100644 index 0000000000000..c70a2420e8387 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + s_state, 0 as g_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + from + store_sales, customer_demographics, date_dim, store, item + where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_store_sk = s_store_sk + and ss_cdemo_sk = cd_demo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and d_year = 1998 + and s_state in ('TN','TN', 'TN', 'TN', 'TN', 'TN')) +select + i_item_id, + s_state, + g_state, + agg1, + agg2, + agg3, + agg4 +from ( + select + i_item_id, + s_state, + 0 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results + group by + i_item_id, + s_state + union all + select + i_item_id, + NULL AS s_state, + 1 AS g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as s_state, + 1 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results) foo +order by + i_item_id, + s_state +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql new file mode 100755 index 0000000000000..bbede62acc9a7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql @@ -0,0 +1,37 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag DESC, + ss_ticket_number -- This order-by condition did not exist in TPCDS v1.4 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql new file mode 100755 index 0000000000000..27116a563d5c6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql @@ -0,0 +1,65 @@ +SELECT + -- select list of q35 in TPCDS v1.4 is below: + -- ca_state, + -- cd_gender, + -- cd_marital_status, + -- count(*) cnt1, + -- min(cd_dep_count), + -- max(cd_dep_count), + -- avg(cd_dep_count), + -- cd_dep_employed_count, + -- count(*) cnt2, + -- min(cd_dep_employed_count), + -- max(cd_dep_employed_count), + -- avg(cd_dep_employed_count), + -- cd_dep_college_count, + -- count(*) cnt3, + -- min(cd_dep_college_count), + -- max(cd_dep_college_count), + -- avg(cd_dep_college_count) + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql new file mode 100644 index 0000000000000..1c1463e44777f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql @@ -0,0 +1,62 @@ +-- This is a new query in TPCDS v2.7 +select + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +from + customer c, customer_address ca, customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales, date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) + and exists ( + select * + from ( + select ws_bill_customer_sk customsk + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4 + union all + select cs_ship_customer_sk customsk + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) x + where x.customsk = c.c_customer_sk) +group by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql new file mode 100644 index 0000000000000..9d98f32add508 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as ss_net_profit, + sum(ss_ext_sales_price) as ss_ext_sales_price, + sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin, + i_category, + i_class, + 0 as g_category, + 0 as g_class + from + store_sales, + date_dim d1, + item, + store + where + d1.d_year = 2001 + and d1.d_date_sk = ss_sold_date_sk + and i_item_sk = ss_item_sk + and s_store_sk = ss_store_sk + and s_state in ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') + group by + i_category, + i_class), + results_rollup as ( + select + gross_margin, + i_category, + i_class, + 0 as t_category, + 0 as t_class, + 0 as lochierarchy + from + results + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + i_category, NULL AS i_class, + 0 as t_category, + 1 as t_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + NULL AS i_category, + NULL AS i_class, + 1 as t_category, + 1 as t_class, + 2 as lochierarchy + from + results) +select + gross_margin, + i_category, + i_class, + lochierarchy, + rank() over ( + partition by lochierarchy, case when t_class = 0 then i_category end + order by gross_margin asc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql new file mode 100755 index 0000000000000..9f7ee457ea45f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql @@ -0,0 +1,64 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + -- q47 in TPCDS v1.4 had more columns below: + -- v1.i_brand, + -- v1.s_store_name, + -- v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql new file mode 100755 index 0000000000000..e8061bde4159e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql @@ -0,0 +1,133 @@ +-- The first SELECT query below is different from q49 of TPCDS v1.4 +SELECT + channel, + item, + return_ratio, + return_rank, + currency_rank +FROM ( + SELECT + 'web' as channel, + in_web.item, + in_web.return_ratio, + in_web.return_rank, + in_web.currency_rank + FROM + (SELECT + item, + return_ratio, + currency_ratio, + rank() over (ORDER BY return_ratio) AS return_rank, + rank() over (ORDER BY currency_ratio) AS currency_rank + FROM ( + SELECT + ws.ws_item_sk AS item, + CAST(SUM(COALESCE(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_quantity, 0)) AS DECIMAL(15, 4)) AS return_ratio, + CAST(SUM(COALESCE(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_net_paid, 0)) AS DECIMAL(15, 4)) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND ws.ws_item_sk = wr.wr_item_sk), + date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY + ws.ws_item_sk) + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY + -- order-by list of q49 in TPCDS v1.4 is below: + -- 1, 4, 5 + 1, 4, 5, 2 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql new file mode 100644 index 0000000000000..b8cbbbc8ef7d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql @@ -0,0 +1,103 @@ +-- This is a new query in TPCDS v2.7 +WITH web_tv as ( + select + ws_item_sk item_sk, + d_date, + sum(ws_sales_price) sumws, + row_number() over (partition by ws_item_sk order by d_date) rk + from + web_sales, date_dim + where + ws_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ws_item_sk is not NULL + group by + ws_item_sk, d_date), +web_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumws, + sum(v2.sumws) cume_sales + from + web_tv v1, web_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumws), +store_tv as ( + select + ss_item_sk item_sk, + d_date, + sum(ss_sales_price) sumss, + row_number() over (partition by ss_item_sk order by d_date) rk + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_item_sk is not NULL + group by ss_item_sk, d_date), +store_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumss, + sum(v2.sumss) cume_sales + from + store_tv v1, store_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumss), +v as ( + select + item_sk, + d_date, + web_sales, + store_sales, + row_number() over (partition by item_sk order by d_date) rk + from ( + select + case when web.item_sk is not null + then web.item_sk + else store.item_sk end item_sk, + case when web.d_date is not null + then web.d_date + else store.d_date end d_date, + web.cume_sales web_sales, + store.cume_sales store_sales + from + web_v1 web full outer join store_v1 store + on (web.item_sk = store.item_sk and web.d_date = store.d_date))) +select * +from ( + select + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales, + max(v2.web_sales) web_cumulative, + max(v2.store_sales) store_cumulative + from + v v1, v v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales) x +where + web_cumulative > store_cumulative +order by + item_sk, + d_date +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql new file mode 100755 index 0000000000000..ccefaac3c12ca --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql @@ -0,0 +1,57 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + -- q57 in TPCDS v1.4 had a column below: + -- v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql new file mode 100644 index 0000000000000..42bcf59c2aeb1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql @@ -0,0 +1,158 @@ +-- This is a new query in TPCDS v2.7 +with ssr as( + select + s_store_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ss_store_sk as store_sk, + ss_sold_date_sk as date_sk, + ss_ext_sales_price as sales_price, + ss_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + store_sales + union all + select + sr_store_sk as store_sk, + sr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + sr_return_amt as return_amt, + sr_net_loss as net_loss + from + store_returns) salesreturns, + date_dim, + store + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and store_sk = s_store_sk + group by + s_store_id), +csr as ( + select + cp_catalog_page_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + cs_catalog_page_sk as page_sk, + cs_sold_date_sk as date_sk, + cs_ext_sales_price as sales_price, + cs_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from catalog_sales + union all + select + cr_catalog_page_sk as page_sk, + cr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + cr_return_amount as return_amt, + cr_net_loss as net_loss + from catalog_returns) salesreturns, + date_dim, + catalog_page + where + date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and page_sk = cp_catalog_page_sk + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ws_web_site_sk as wsr_web_site_sk, + ws_sold_date_sk as date_sk, + ws_ext_sales_price as sales_price, + ws_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + web_sales + union all + select + ws_web_site_sk as wsr_web_site_sk, + wr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + wr_return_amt as return_amt, + wr_net_loss as net_loss + from + web_returns + left outer join web_sales on ( + wr_item_sk = ws_item_sk and wr_order_number = ws_order_number) + ) salesreturns, + date_dim, + web_site + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and wsr_web_site_sk = web_site_sk + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || s_store_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || cp_catalog_page_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + wsr) x + group by + channel, id) +select + channel, id, sales, returns, profit +from ( + select channel, id, sales, returns, profit + from results + union + select channel, null as id, sum(sales), sum(returns), sum(profit) + from results + group by channel + union + select null as channel, null as id, sum(sales), sum(returns), sum(profit) + from results) foo + order by channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql new file mode 100755 index 0000000000000..c0bfa40ad44a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql @@ -0,0 +1,23 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +-- order-by list of q6 in TPCDS v1.4 is below: +-- order by cnt +order by cnt, a.ca_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql new file mode 100755 index 0000000000000..cdcd8486b363d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql @@ -0,0 +1,111 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY + i_product_name, + i_item_sk, + s_store_name, + s_zip, + ad1.ca_street_number, + ad1.ca_street_name, + ad1.ca_city, + ad1.ca_zip, + ad2.ca_street_number, + ad2.ca_street_name, + ad2.ca_city, + ad2.ca_zip, + d1.d_year, + d2.d_year, + d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY + cs1.product_name, + cs1.store_name, + cs2.cnt, + -- The two columns below are newly added in TPCDS v2.7 + cs1.s1, + cs2.s1 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql new file mode 100644 index 0000000000000..70a14043bbb3d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql @@ -0,0 +1,208 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + from + store_sales, date_dim, store, item + where + ss_sold_date_sk=d_date_sk + and ss_item_sk=i_item_sk + and ss_store_sk = s_store_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id), +results_rollup as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales + from + results + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year + union all + select + i_category, + i_class, + i_brand, + i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name + union all + select + i_category, + i_class, + i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand + union all + select + i_category, + i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class + union all + select + i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from results + group by + i_category + union all + select + null i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results) +select + * +from ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() over (partition by i_category order by sumsales desc) rk + from results_rollup) dw2 +where + rk <= 100 +order by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rk +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql new file mode 100644 index 0000000000000..4aec9c7fd1fd6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as total_sum, + s_state ,s_county, + 0 as gstate, + 0 as g_county + from + store_sales, date_dim d1, store + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_state in ( + select s_state + from ( + select + s_state as s_state, + rank() over (partition by s_state order by sum(ss_net_profit) desc) as ranking + from store_sales, store, date_dim + where d_month_seq between 1212 and 1212 + 11 + and d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + group by s_state) tmp1 + where ranking <= 5) + group by + s_state, s_county), +results_rollup as ( + select + total_sum, + s_state, + s_county, + 0 as g_state, + 0 as g_county, + 0 as lochierarchy + from results + union + select + sum(total_sum) as total_sum,s_state, + NULL as s_county, + 0 as g_state, + 1 as g_county, + 1 as lochierarchy + from results + group by s_state + union + select + sum(total_sum) as total_sum, + NULL as s_state, + NULL as s_county, + 1 as g_state, + 1 as g_county, + 2 as lochierarchy + from results) +select + total_sum, + s_state, + s_county, + lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_county = 0 then s_state end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then s_state end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql new file mode 100755 index 0000000000000..066d6a587e917 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql @@ -0,0 +1,40 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +-- q72 in TPCDS v1.4 had conditions below: +-- WHERE d1.d_week_seq = d2.d_week_seq +-- AND inv_quantity_on_hand < cs_quantity +-- AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) +-- AND hd_buy_potential = '>10000' +-- AND d1.d_year = 1999 +-- AND hd_buy_potential = '>10000' +-- AND cd_marital_status = 'D' +-- AND d1.d_year = 1999 +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > d1.d_date + INTERVAL 5 days + AND hd_buy_potential = '1001-5000' + AND d1.d_year = 2001 + AND cd_marital_status = 'M' +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql new file mode 100755 index 0000000000000..94a0063b36c0c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql @@ -0,0 +1,60 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +-- order-by list of q74 in TPCDS v1.4 is below: +-- ORDER BY 1, 1, 1 +ORDER BY 2, 1, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql new file mode 100755 index 0000000000000..ae5dc97ef2317 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql @@ -0,0 +1,78 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY + sales_cnt_diff, + sales_amt_diff -- This order-by condition did not exist in TPCDS v1.4 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql new file mode 100644 index 0000000000000..fc69c43470f1e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql @@ -0,0 +1,121 @@ +-- This is a new query in TPCDS v2.7 +with ss as ( + select + s_store_sk, + sum(ss_ext_sales_price) as sales, + sum(ss_net_profit) as profit + from + store_sales, date_dim, store + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + group by + s_store_sk), +sr as ( + select + s_store_sk, + sum(sr_return_amt) as returns, + sum(sr_net_loss) as profit_loss + from + store_returns, date_dim, store + where + sr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and sr_store_sk = s_store_sk + group by + s_store_sk), +cs as ( + select + cs_call_center_sk, + sum(cs_ext_sales_price) as sales, + sum(cs_net_profit) as profit + from + catalog_sales, + date_dim + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + group by + cs_call_center_sk), + cr as ( + select + sum(cr_return_amount) as returns, + sum(cr_net_loss) as profit_loss + from catalog_returns, + date_dim + where + cr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days)), +ws as ( select wp_web_page_sk, + sum(ws_ext_sales_price) as sales, + sum(ws_net_profit) as profit + from web_sales, + date_dim, + web_page + where ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_page_sk = wp_web_page_sk + group by wp_web_page_sk), + wr as + (select wp_web_page_sk, + sum(wr_return_amt) as returns, + sum(wr_net_loss) as profit_loss + from web_returns, + date_dim, + web_page + where wr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and wr_web_page_sk = wp_web_page_sk + group by wp_web_page_sk) + , + results as + (select channel + , id + , sum(sales) as sales + , sum(returns) as returns + , sum(profit) as profit + from + (select 'store channel' as channel + , ss.s_store_sk as id + , sales + , coalesce(returns, 0) as returns + , (profit - coalesce(profit_loss,0)) as profit + from ss left join sr + on ss.s_store_sk = sr.s_store_sk + union all + select 'catalog channel' as channel + , cs_call_center_sk as id + , sales + , returns + , (profit - profit_loss) as profit + from cs + , cr + union all + select 'web channel' as channel + , ws.wp_web_page_sk as id + , sales + , coalesce(returns, 0) returns + , (profit - coalesce(profit_loss,0)) as profit + from ws left join wr + on ws.wp_web_page_sk = wr.wp_web_page_sk + ) x + group by channel, id ) + + select * + from ( + select channel, id, sales, returns, profit from results + union + select channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results group by channel + union + select NULL AS channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results +) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql new file mode 100755 index 0000000000000..d03d8af77174c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql @@ -0,0 +1,75 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + -- order-by list of q78 in TPCDS v1.4 is below: + -- ratio, + -- ss_qty DESC, ss_wc DESC, ss_sp DESC, + -- other_chan_qty, + -- other_chan_wholesale_cost, + -- other_chan_sales_price, + -- round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) + ss_sold_year, + ss_item_sk, + ss_customer_sk, + ss_qty desc, + ss_wc desc, + ss_sp desc, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + ratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql new file mode 100644 index 0000000000000..686e03ba2a6d0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql @@ -0,0 +1,147 @@ +-- This is a new query in TPCDS v2.7 +with ssr as ( + select + s_store_id as store_id, + sum(ss_ext_sales_price) as sales, + sum(coalesce(sr_return_amt, 0)) as returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit + from + store_sales left outer join store_returns on ( + ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number), + date_dim, + store, + item, + promotion + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + and ss_item_sk = i_item_sk + and i_current_price > 50 + and ss_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + s_store_id), +csr as ( + select + cp_catalog_page_id as catalog_page_id, + sum(cs_ext_sales_price) as sales, + sum(coalesce(cr_return_amount, 0)) as returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit + from + catalog_sales left outer join catalog_returns on + (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number), + date_dim, + catalog_page, + item, + promotion + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and cs_catalog_page_sk = cp_catalog_page_sk + and cs_item_sk = i_item_sk + and i_current_price > 50 + and cs_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(ws_ext_sales_price) as sales, + sum(coalesce(wr_return_amt, 0)) as returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit + from + web_sales left outer join web_returns on ( + ws_item_sk = wr_item_sk and ws_order_number = wr_order_number), + date_dim, + web_site, + item, + promotion + where + ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_site_sk = web_site_sk + and ws_item_sk = i_item_sk + and i_current_price > 50 + and ws_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || store_id as id, + sales, + returns, + profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || catalog_page_id as id, + sales, + returns, + profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + profit + from + wsr) x + group by + channel, id) +select + channel, + id, + sales, + returns, + profit +from ( + select + channel, + id, + sales, + returns, + profit + from + results + union + select + channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results + group by + channel + union + select + NULL AS channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql new file mode 100644 index 0000000000000..fff76b08d4ba0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql @@ -0,0 +1,61 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ws_net_paid) as total_sum, + i_category, i_class, + 0 as g_category, + 0 as g_class + from + web_sales, date_dim d1, item + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ws_sold_date_sk + and i_item_sk = ws_item_sk + group by + i_category, i_class), +results_rollup as( + select + total_sum, + i_category, + i_class, + g_category, + g_class, + 0 as lochierarchy + from + results + union + select + sum(total_sum) as total_sum, + i_category, + NULL as i_class, + 0 as g_category, + 1 as g_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(total_sum) as total_sum, + NULL as i_category, + NULL as i_class, + 1 as g_category, + 1 as g_class, + 2 as lochierarchy + from + results) +select + total_sum, + i_category ,i_class, lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_class = 0 then i_category end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql new file mode 100755 index 0000000000000..771117add2ed2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql @@ -0,0 +1,22 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index 1a584187a06e5..bc95b4696190d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -62,7 +62,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, - |`c_email_address` STRING, `c_last_review_date` STRING) + |`c_email_address` STRING, `c_last_review_date` INT) |USING parquet """.stripMargin) @@ -88,7 +88,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `date_dim` ( - |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_date_sk` INT, `d_date_id` STRING, `d_date` DATE, |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, @@ -115,8 +115,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ - |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, - |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` DATE, + |`i_rec_end_date` DATE, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, @@ -139,8 +139,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `store` ( - |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, - |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` DATE, + |`s_rec_end_date` DATE, `s_closed_date_sk` INT, `s_store_name` STRING, |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, @@ -157,7 +157,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, - |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_quantity` INT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), @@ -225,7 +225,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, - |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` INT, |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), @@ -244,7 +244,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, - |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |`web_country` STRING, `web_gmt_offset` DECIMAL(5,2), `web_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) @@ -315,6 +315,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { """.stripMargin) } + // The TPCDS queries below are based on v1.4 val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -339,6 +340,25 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { } } + // This list only includes TPCDS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7_0 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + + tpcdsQueriesV2_7_0.foreach { name => + val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"$name-v2.7") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } + } + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries val modifiedTPCDSQueries = Seq( "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", From e4bec7cb88b9ee63f8497e3f9e0ab0bfa5d5a77c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 25 Mar 2018 16:38:49 -0700 Subject: [PATCH 212/593] [SPARK-23549][SQL] Cast to timestamp when comparing timestamp with date ## What changes were proposed in this pull request? This PR fixes an incorrect comparison in SQL between timestamp and date. This is because both of them are casted to `string` and then are compared lexicographically. This implementation shows `false` regarding this query `spark.sql("select cast('2017-03-01 00:00:00' as timestamp) between cast('2017-02-28' as date) and cast('2017-03-01' as date)").show`. This PR shows `true` for this query by casting `date("2017-03-01")` to `timestamp("2017-03-01 00:00:00")`. (Please fill in changes proposed in this fix) ## How was this patch tested? Added new UTs to `TypeCoercionSuite`. Author: Kazuaki Ishizaki Closes #20774 from kiszk/SPARK-23549. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 ++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 13 ++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 34 ++++++++++++--- .../sql-tests/inputs/predicate-functions.sql | 7 ++++ .../results/predicate-functions.sql.out | 42 ++++++++++++++++++- 6 files changed, 108 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 421e2eaf62bfb..2b393f30d1435 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1808,6 +1808,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e8669c4637d06..ec7e7761dc4c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -47,9 +47,9 @@ import org.apache.spark.sql.types._ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = - InConversion :: + InConversion(conf) :: WidenSetOperationTypes :: - PromoteStrings :: + PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: FunctionArgumentConversion :: @@ -127,7 +127,8 @@ object TypeCoercion { * is a String and the other is not. It also handles when one op is a Date and the * other is a Timestamp by making the target type to be String. */ - val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + private def findCommonTypeForBinaryComparison( + dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true @@ -135,11 +136,17 @@ object TypeCoercion { case (DateType, StringType) => Some(StringType) case (StringType, TimestampType) => Some(StringType) case (TimestampType, StringType) => Some(StringType) - case (TimestampType, DateType) => Some(StringType) - case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + // Cast to TimestampType when we compare DateType with TimestampType + // if conf.compareDateTimestampInTimestamp is true + // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true + case (TimestampType, DateType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + case (DateType, TimestampType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + // There is no proper decimal type we can pick, // using double type is the best we can do. // See SPARK-22469 for details. @@ -147,7 +154,7 @@ object TypeCoercion { case (s: StringType, n: DecimalType) => Some(DoubleType) case (l: StringType, r: AtomicType) if r != StringType => Some(r) - case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l: AtomicType, r: StringType) if l != StringType => Some(l) case (l, r) => None } @@ -313,7 +320,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. */ - object PromoteStrings extends TypeCoercionRule { + case class PromoteStrings(conf: SQLConf) extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -342,8 +349,8 @@ object TypeCoercion { p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) - if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => - val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) @@ -374,7 +381,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends TypeCoercionRule { + case class InConversion(conf: SQLConf) extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. @@ -400,7 +407,7 @@ object TypeCoercion { val rhs = sub.output val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) + findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 11864bd1b1847..9cb03b5bb6152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -479,6 +479,16 @@ object SQLConf { .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = + buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp") + .internal() + .doc("When true (default), compare Date with Timestamp after converting both sides to " + + "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + + "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + + "converting both sides to string. This config will be removed in spark 3.0") + .booleanConf + .createWithDefault(true) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -1332,6 +1342,9 @@ class SQLConf extends Serializable with Logging { def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def compareDateTimestampInTimestamp : Boolean = + getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 52a7ebdafd7c7..8ac49dc05e3cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1207,7 +1207,7 @@ class TypeCoercionSuite extends AnalysisTest { */ test("make sure rules do not fire early") { // InConversion - val inConversion = TypeCoercion.InConversion + val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, In(UnresolvedAttribute("a"), Seq(Literal(1))), In(UnresolvedAttribute("a"), Seq(Literal(1))) @@ -1251,18 +1251,40 @@ class TypeCoercionSuite extends AnalysisTest { } test("binary comparison with string promotion") { - ruleTest(PromoteStrings, + val rule = TypeCoercion.PromoteStrings(conf) + ruleTest(rule, GreaterThan(Literal("123"), Literal(1)), GreaterThan(Cast(Literal("123"), IntegerType), Literal(1))) - ruleTest(PromoteStrings, + ruleTest(rule, LessThan(Literal(true), Literal("123")), LessThan(Literal(true), Cast(Literal("123"), BooleanType))) - ruleTest(PromoteStrings, + ruleTest(rule, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) - ruleTest(PromoteStrings, + ruleTest(rule, GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), - GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), + DoubleType))) + Seq(true, false).foreach { convertToTS => + withSQLConf( + "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) { + val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) + val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) + val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) + if (convertToTS) { + // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549 + ruleTest(rule, EqualTo(date0301, timestamp0301000000), + EqualTo(Cast(date0301, TimestampType), timestamp0301000000)) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, TimestampType), timestamp0301000001)) + } else { + ruleTest(rule, LessThan(date0301, timestamp0301000000), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType))) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType))) + } + } + } } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index e99d5cef81f64..fadb4bb27fa13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -39,3 +39,10 @@ select 2.0 <= '2.2'; select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; + +-- SPARK-23549: Cast to timestamp when comparing timestamp with date +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00'); +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01'); +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01'); +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01'); +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01'); diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index d51f6d37e4b41..cf828c69af62a 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 37 -- !query 0 @@ -256,3 +256,43 @@ select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> -- !query 31 output true + + +-- !query 32 +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00') +-- !query 32 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) = to_timestamp('2017-03-01 00:00:00')):boolean> +-- !query 32 output +true + + +-- !query 33 +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01') +-- !query 33 schema +struct<(to_timestamp('2017-03-01 00:00:01') > CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 33 output +true + + +-- !query 34 +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01') +-- !query 34 schema +struct<(to_timestamp('2017-03-01 00:00:01') >= CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 34 output +true + + +-- !query 35 +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01') +-- !query 35 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) < to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 35 output +true + + +-- !query 36 +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01') +-- !query 36 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) <= to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 36 output +true From a9350d7095b79c8374fb4a06fd3f1a1a67615f6f Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 26 Mar 2018 12:42:32 +0900 Subject: [PATCH 213/593] [SPARK-23700][PYTHON] Cleanup imports in pyspark.sql ## What changes were proposed in this pull request? This cleans up unused imports, mainly from pyspark.sql module. Added a note in function.py that imports `UserDefinedFunction` only to maintain backwards compatibility for using `from pyspark.sql.function import UserDefinedFunction`. ## How was this patch tested? Existing tests and built docs. Author: Bryan Cutler Closes #20892 from BryanCutler/pyspark-cleanup-imports-SPARK-23700. --- python/pyspark/sql/column.py | 1 - python/pyspark/sql/conf.py | 1 - python/pyspark/sql/functions.py | 3 +-- python/pyspark/sql/group.py | 3 +-- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 -- python/pyspark/sql/types.py | 1 - python/pyspark/sql/udf.py | 6 ++---- python/pyspark/util.py | 2 -- 9 files changed, 5 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e05a7b33c11a7..922c7cf288f8f 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,7 +16,6 @@ # import sys -import warnings import json if sys.version >= '3': diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index b82224b6194ed..db49040e17b63 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -67,7 +67,6 @@ def _checkType(self, obj, identifier): def _test(): import os import doctest - from pyspark.context import SparkContext from pyspark.sql.session import SparkSession import pyspark.sql.conf diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dff590983b4d9..a4edb1e27b599 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,7 +18,6 @@ """ A collections of builtin functions """ -import math import sys import functools import warnings @@ -28,10 +27,10 @@ from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType +# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 35cac406e0965..3505065b648f2 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,9 +19,8 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal +from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e5288636c596e..4f9b9383a5ef4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,7 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD, since, keyword_only +from pyspark import RDD, since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 07f9ac1b5aa9e..c7907aaaf1f7b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,8 +24,6 @@ else: intlike = (int, long) -from abc import ABCMeta, abstractmethod - from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5d5919e451b46..1f6534836d64a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,7 +35,6 @@ from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer -from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 24dd06c26089c..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,16 +17,14 @@ """ User-defined function related classes and functions """ -import sys -import inspect import functools import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ + to_arrow_type, to_arrow_schema from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ed1bdd0e4be83..49afc13640332 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,8 +22,6 @@ __all__ = [] -import sys - def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both From 087fb3142028d679524e22596b0ad4f74ff47e8d Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 26 Mar 2018 12:45:45 +0900 Subject: [PATCH 214/593] [SPARK-23645][MINOR][DOCS][PYTHON] Add docs RE `pandas_udf` with keyword args ## What changes were proposed in this pull request? Add documentation about the limitations of `pandas_udf` with keyword arguments and related concepts, like `functools.partial` fn objects. NOTE: intermediate commits on this PR show some of the steps that can be taken to fix some (but not all) of these pain points. ### Survey of problems we face today: (Initialize) Note: python 3.6 and spark 2.4snapshot. ``` from pyspark.sql import SparkSession import inspect, functools from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, udf spark = SparkSession.builder.getOrCreate() print(spark.version) df = spark.range(1,6).withColumn('b', col('id') * 2) def ok(a,b): return a+b ``` Using a keyword argument at the call site `b=...` (and yes, *full* stack trace below, haha): ``` ---> 14 df.withColumn('ok', pandas_udf(f=ok, returnType='bigint')('id', b='id')).show() # no kwargs TypeError: wrapper() got an unexpected keyword argument 'b' ``` Using partial with a keyword argument where the kw-arg is the first argument of the fn: *(Aside: kind of interesting that lines 15,16 work great and then 17 explodes)* ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) in () 15 df.withColumn('ok', pandas_udf(f=functools.partial(ok, 7), returnType='bigint')('id')).show() 16 df.withColumn('ok', pandas_udf(f=functools.partial(ok, b=7), returnType='bigint')('id')).show() ---> 17 df.withColumn('ok', pandas_udf(f=functools.partial(ok, a=7), returnType='bigint')('id')).show() /Users/stu/ZZ/spark/python/pyspark/sql/functions.py in pandas_udf(f, returnType, functionType) 2378 return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) 2379 else: -> 2380 return _create_udf(f=f, returnType=return_type, evalType=eval_type) 2381 2382 /Users/stu/ZZ/spark/python/pyspark/sql/udf.py in _create_udf(f, returnType, evalType) 54 argspec.varargs is None: 55 raise ValueError( ---> 56 "Invalid function: 0-arg pandas_udfs are not supported. " 57 "Instead, create a 1-arg pandas_udf and ignore the arg in your function." 58 ) ValueError: Invalid function: 0-arg pandas_udfs are not supported. Instead, create a 1-arg pandas_udf and ignore the arg in your function. ``` Author: Michael (Stu) Stewart Closes #20900 from mstewart141/udfkw2. --- python/pyspark/sql/functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4edb1e27b599..ad3e37c872628 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2154,6 +2154,8 @@ def udf(f=None, returnType=StringType()): in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + .. note:: The user-defined functions do not take keyword arguments on the calling side. + :param f: python function if used as a standalone function :param returnType: the return type of the user-defined function. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. @@ -2337,6 +2339,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + + .. note:: The user-defined functions do not take keyword arguments on the calling side. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From eb48edf9ca4f4b42c63f145718696472cb6a31ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 14:01:04 +0800 Subject: [PATCH 215/593] [SPARK-23787][TESTS] Fix file download test in SparkSubmitSuite for Hadoop 2.9. This particular test assumed that Hadoop libraries did not support http as a file system. Hadoop 2.9 does, so the test failed. The test now forces a non-existent implementation for the http fs, which forces the expected error. There were also a couple of other issues in the same test: SparkSubmit arguments in the wrong order, and the wrong check later when asserting, which was being masked by the previous issues. Author: Marcelo Vanzin Closes #20895 from vanzin/SPARK-23787. --- .../spark/deploy/SparkSubmitSuite.scala | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2d0c192db4915..d86ef907b4492 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -959,25 +959,28 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) } test("force download from blacklisted schemes") { - testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) } - private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, - supportMockHttpFs: Boolean): Unit = { + private def testRemoteResources( + enableHttpFs: Boolean, + blacklistHttpFs: Boolean): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - if (supportMockHttpFs) { + if (enableHttpFs) { hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) - hadoopConf.set("fs.http.impl.disable.cache", "true") + } else { + hadoopConf.set("fs.http.impl", getClass().getName() + ".DoesNotExist") } + hadoopConf.set("fs.http.impl.disable.cache", "true") val tmpDir = Utils.createTempDir() val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) @@ -986,20 +989,19 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + val forceDownloadArgs = if (blacklistHttpFs) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + } else { + Nil + } + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", "--master", "yarn", "--deploy-mode", "client", - "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", - s"s3a://$mainResource" - ) ++ ( - if (isHttpSchemeBlacklisted) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") - } else { - Nil - } - ) + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath" + ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) @@ -1009,7 +1011,7 @@ class SparkSubmitSuite // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) - if (supportMockHttpFs) { + if (enableHttpFs && !blacklistHttpFs) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) From b30a7d28b399950953d4b112c57d4c9b9ab223e9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 12:45:45 -0700 Subject: [PATCH 216/593] [SPARK-23572][DOCS] Bring "security.md" up to date. This change basically rewrites the security documentation so that it's up to date with new features, more correct, and more complete. Because security is such an important feature, I chose to move all the relevant configuration documentation to the security page, instead of having them peppered all over the place in the configuration page. This allows an almost one-stop shop for security configuration in Spark. The only exceptions are some YARN-specific minor features which I left in the YARN page. I also re-organized the page's topics, since they didn't make a lot of sense. You had kerberos features described inside paragraphs talking about UI access control, and other oddities. It should be easier now to find information about specific Spark security features. I also enabled TOCs for both the Security and YARN pages, since that makes it easier to see what is covered. I removed most of the comments from the SecurityManager javadoc since they just replicated information in the security doc, with different levels of out-of-dateness. Author: Marcelo Vanzin Closes #20742 from vanzin/SPARK-23572. --- .gitignore | 1 + .../org/apache/spark/SecurityManager.scala | 144 +--- docs/configuration.md | 359 +--------- docs/monitoring.md | 40 +- docs/running-on-yarn.md | 203 +++--- docs/security.md | 629 +++++++++++++++--- 6 files changed, 673 insertions(+), 703 deletions(-) diff --git a/.gitignore b/.gitignore index 39085904e324c..e4c44d0590d59 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ streaming-tests.log target/ unit-tests.log work/ +docs/.jekyll-metadata # For Hive TempStatsStore/ diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index da1c89cd78901..09ec8932353a0 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -42,148 +42,10 @@ import org.apache.spark.util.Utils * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. * - * Spark currently supports authentication via a shared secret. - * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do - * authentication using the shared secret. This authentication is a basic handshake to - * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.acls.enable', 'spark.ui.view.acls' and - * 'spark.ui.view.acls.groups' control the behavior of the acls. Note that the person who - * started the application always has view access to the UI. - * - * Spark has a set of individual and group modify acls (`spark.modify.acls`) and - * (`spark.modify.acls.groups`) that controls which users and groups have permission to - * modify a single application. This would include things like killing the application. - * By default the person who started the application has modify access. For modify access - * through the UI, you must have a filter that does authentication in place for the modify - * acls to work properly. - * - * Spark also has a set of individual and group admin acls (`spark.admin.acls`) and - * (`spark.admin.acls.groups`) which is a set of users/administrators and admin groups - * who always have permission to view or modify the Spark application. - * - * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. - * - * At this point spark has multiple communication protocols that need to be secured and - * different underlying mechanisms are used depending on the protocol: - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spnego, etc. It also supports multiple different login - * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. - * Since we are using DIGEST-MD5, the shared secret is not passed on the wire - * in plaintext. - * - * We currently support SSL (https) for this communication protocol (see the details - * below). - * - * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default - * Authenticator installed in the SecurityManager to how it does the authentication - * and in this case gets the user name and password from the request. - * - * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 - * as the authentication mechanism. This means the shared secret is not passed - * over the wire in plaintext. - * Note that SASL is pluggable as to what mechanism it uses. We currently use - * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. - * Spark currently supports "auth" for the quality of protection, which means - * the connection does not support integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which - * SPARK could support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to - * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the NioBlockTransferService does asynchronous messages passing, the SASL - * authentication is a bit more complex. A ConnectionManager can be both a client - * and a Server, so for a particular connection it has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to - * match up incoming messages with connections waiting for authentication. - * The ConnectionManager tracks all the sendingConnections using the ConnectionId, - * waits for the response from the server, and does the handshake before sending - * the real message. - * - * The NettyBlockTransferService ensures that SASL authentication is performed - * synchronously prior to any other communication on a connection. This is done in - * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. - * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters - * can be used. Yarn requires a specific AmIpFilter be installed for security to work - * properly. For non-Yarn deployments, users can write a filter to go through their - * organization's normal login service. If an authentication filter is in place then the - * SparkUI can be configured to check the logged in user against the list of users who - * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters - * could be used for logging, encryption, or compression. - * - * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. - * - * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop - * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to - * support different levels of protection. See the Hadoop documentation for more details. Each - * Spark application on YARN gets a different shared secret. - * - * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user - * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web - * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That - * authentication then happens via the ResourceManager Proxy and Spark will use that to do - * authorization against the view acls. - * - * For other Spark deployments, the shared secret must be specified via the - * spark.authenticate.secret config. - * All the nodes (Master and Workers) and the applications need to have the same shared secret. - * This again is not ideal as one user could potentially affect another users application. - * This should be enhanced in the future to provide better protection. - * If the UI needs to be secure, the user needs to install a javax servlet filter to do the - * authentication. Spark will then use that user to compare against the view acls to do - * authorization. If not filter is in place the user is generally null and no authorization - * can take place. - * - * When authentication is being used, encryption can also be enabled by setting the option - * spark.authenticate.enableSaslEncryption to true. This is only supported by communication - * channels that use the network-common library, and can be used as an alternative to SSL in those - * cases. - * - * SSL can be used for encryption for certain communication channels. The user can configure the - * default SSL settings which will be used for all the supported communication protocols unless - * they are overwritten by protocol specific settings. This way the user can easily provide the - * common settings for all the protocols without disabling the ability to configure each one - * individually. - * - * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, - * denote the global configuration for all the supported protocols. In order to override the global - * configuration for the particular protocol, the properties must be overwritten in the - * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global - * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be only`fs` for - * broadcast and file server. - * - * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of - * options that can be specified. - * - * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions - * object parses Spark configuration at a given namespace and builds the common representation - * of SSL settings. SSLOptions is then used to provide protocol-specific SSLContextFactory for - * Jetty. - * - * SSL must be configured on each node and configured for each component involved in - * communication using the particular protocol. In YARN clusters, the key-store can be prepared on - * the client side then distributed and used by the executors as the part of the application - * (YARN allows the user to deploy files before the application is started). - * In standalone deployment, the user needs to provide key-stores and configuration - * options for master and workers. In this mode, the user may allow the executors to use the SSL - * settings inherited from the worker which spawned that executor. It can be accomplished by - * setting `spark.ssl.useNodeLocalConf` to `true`. + * This class implements all of the configuration related to security features described + * in the "Security" document. Please refer to that document for specific features implemented + * here. */ - private[spark] class SecurityManager( sparkConf: SparkConf, val ioEncryptionKey: Option[Array[Byte]] = None) diff --git a/docs/configuration.md b/docs/configuration.md index e7f2419cc2fa4..2eb6a77434ea6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -712,30 +712,6 @@ Apart from these, the following properties are also available, and may be useful When we fail to register to the external shuffle service, we will retry for maxAttempts times. - - spark.io.encryption.enabled - false - - Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption - be enabled when using this feature. - - - - spark.io.encryption.keySizeBits - 128 - - IO encryption key size in bits. Supported values are 128, 192 and 256. - - - - spark.io.encryption.keygen.algorithm - HmacSHA1 - - The algorithm to use when generating the IO encryption key. The supported algorithms are - described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm - Name Documentation. - - ### Spark UI @@ -893,6 +869,23 @@ Apart from these, the following properties are also available, and may be useful How many dead executors the Spark UI and status APIs remember before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark Web UI. The filter should be a + standard + javax servlet Filter. + +
    Filter parameters can also be specified in the configuration, by setting config entries + of the form spark.<class name of filter>.param.<param name>=<value> + +
    For example: +
    spark.ui.filters=com.test.filter1 +
    spark.com.test.filter1.param.name1=foo +
    spark.com.test.filter1.param.name2=bar + + ### Compression and Serialization @@ -1446,6 +1439,15 @@ Apart from these, the following properties are also available, and may be useful Duration for an RPC remote endpoint lookup operation to wait before timing out. + + spark.core.connection.ack.wait.timeout + spark.network.timeout + + How long for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + ### Scheduling @@ -1817,313 +1819,8 @@ Apart from these, the following properties are also available, and may be useful ### Security - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.acls.enablefalse - Whether Spark acls should be enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
    spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things do not work. Putting a "*" in the list means any user can have the - privilege of admin. -
    spark.admin.acls.groupsEmpty - Comma separated list of groups that have view and modify access to all Spark jobs. - This can be used if you have a set of administrators or developers who help maintain and debug - the underlying infrastructure. Putting a "*" in the list means any user in any group can have - the privilege of admin. The user groups are obtained from the instance of the groups mapping - provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user is determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. - A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider - which can be specified to resolve a list of groups for a user. - Note: This implementation supports only a Unix/Linux based environment. Windows environment is - currently not supported. However, a new platform/protocol can be supported by implementing - the trait org.apache.spark.security.GroupMappingServiceProvider. -
    spark.authenticatefalse - Whether Spark authenticates its internal connections. See - spark.authenticate.secret if not running on YARN. -
    spark.authenticate.secretNone - Set the secret key used for Spark to authenticate between components. This needs to be set if - not running on YARN and authentication is enabled. -
    spark.network.crypto.enabledfalse - Enable encryption using the commons-crypto library for RPC and block transfer service. - Requires spark.authenticate to be enabled. -
    spark.network.crypto.keyLength128 - The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. -
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 - The key factory algorithm to use when generating encryption keys. Should be one of the - algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. -
    spark.network.crypto.saslFallbacktrue - Whether to fall back to SASL authentication if authentication fails using Spark's internal - mechanism. This is useful when the application is connecting to old shuffle services that - do not support the internal Spark authentication protocol. On the server side, this can be - used to block older clients from authenticating against a new shuffle service. -
    spark.network.crypto.config.*None - Configuration values for the commons-crypto library, such as which cipher implementations to - use. The config name should be the name of commons-crypto configuration without the - "commons.crypto" prefix. -
    spark.authenticate.enableSaslEncryptionfalse - Enable encrypted communication when authentication is - enabled. This is supported by the block transfer service and the - RPC endpoints. -
    spark.network.sasl.serverAlwaysEncryptfalse - Disable unencrypted connections for services that support SASL authentication. -
    spark.core.connection.ack.wait.timeoutspark.network.timeout - How long for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. -
    spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). Putting a "*" in - the list means any user can have access to modify it. -
    spark.modify.acls.groupsEmpty - Comma separated list of groups that have modify access to the Spark job. This can be used if you - have a set of administrators or developers from the same team to have access to control the job. - Putting a "*" in the list means any user in any group has the access to modify the Spark job. - The user groups are obtained from the instance of the groups mapping provider specified by - spark.user.groups.mapping. Check the entry spark.user.groups.mapping - for more details. -
    spark.ui.filtersNone - Comma separated list of filter class names to apply to the Spark web UI. The filter should be a - standard - javax servlet Filter. Parameters to each filter can also be specified by setting a - java system property of:
    - spark.<class name of filter>.params='param1=value1,param2=value2'
    - For example:
    - -Dspark.ui.filters=com.test.filter1
    - -Dspark.com.test.filter1.params='param1=foo,param2=testing' -
    spark.ui.view.aclsEmpty - Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. Putting a "*" in the list means any user can - have view access to this Spark job. -
    spark.ui.view.acls.groupsEmpty - Comma separated list of groups that have view access to the Spark web ui to view the Spark Job - details. This can be used if you have a set of administrators or developers or users who can - monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view - the Spark job details on the Spark web ui. The user groups are obtained from the instance of the - groups mapping provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    - -### TLS / SSL - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.ssl.enabledfalse - Whether to enable SSL connections on all supported protocols. - -
    When spark.ssl.enabled is configured, spark.ssl.protocol - is required. - -
    All the SSL settings like spark.ssl.xxx where xxx is a - particular configuration property, denote the global configuration for all the supported - protocols. In order to override the global configuration for the particular protocol, - the properties must be overwritten in the protocol-specific namespace. - -
    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for - particular protocol denoted by YYY. Example values for YYY - include fs, ui, standalone, and - historyServer. See SSL - Configuration for details on hierarchical SSL configuration for services. -
    spark.ssl.[namespace].portNone - The port where the SSL service will listen on. - -
    The port must be defined within a namespace configuration; see - SSL Configuration for the available - namespaces. - -
    When not set, the SSL port will be derived from the non-SSL port for the - same service. A value of "0" will make the service bind to an ephemeral port. -
    spark.ssl.enabledAlgorithmsEmpty - A comma separated list of ciphers. The specified ciphers must be supported by JVM. - The reference list of protocols one can find on - this - page. - Note: If not set, it will use the default cipher suites of JVM. -
    spark.ssl.keyPasswordNone - A password to the private key in key-store. -
    spark.ssl.keyStoreNone - A path to a key-store file. The path can be absolute or relative to the directory where - the component is started in. -
    spark.ssl.keyStorePasswordNone - A password to the key-store. -
    spark.ssl.keyStoreTypeJKS - The type of the key-store. -
    spark.ssl.protocolNone - A protocol name. The protocol must be supported by JVM. The reference list of protocols - one can find on this - page. -
    spark.ssl.needClientAuthfalse - Set true if SSL needs client authentication. -
    spark.ssl.trustStoreNone - A path to a trust-store file. The path can be absolute or relative to the directory - where the component is started in. -
    spark.ssl.trustStorePasswordNone - A password to the trust-store. -
    spark.ssl.trustStoreTypeJKS - The type of the trust-store. -
    - +Please refer to the [Security](security.html) page for available options on how to secure different +Spark subsystems. ### Spark SQL diff --git a/docs/monitoring.md b/docs/monitoring.md index d5f7ffcc260a1..01736c77b0979 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -80,7 +80,10 @@ The history server can be configured as follows: -### Spark configuration options +### Spark History Server Configuration Options + +Security options for the Spark History Server are covered more detail in the +[Security](security.html#web-ui) page. @@ -160,41 +163,6 @@ The history server can be configured as follows: Location of the kerberos keytab file for the History Server. - - - - - - - - - - - - - - - diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c010af35f8d2e..e07759a4dba87 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -2,6 +2,8 @@ layout: global title: Running Spark on YARN --- +* This will become a table of contents (this text will be scraped). +{:toc} Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) @@ -217,8 +219,8 @@ To use a custom metrics.properties for the application master and executors, upd @@ -265,19 +267,6 @@ To use a custom metrics.properties for the application master and executors, upd distribution. - - - - - @@ -373,31 +362,6 @@ To use a custom metrics.properties for the application master and executors, upd in YARN ApplicationReports, which can be used for filtering when querying YARN apps. - - - - - - - - - - - - - - - @@ -424,17 +388,6 @@ To use a custom metrics.properties for the application master and executors, upd See spark.yarn.config.gatewayPath. - - - - - @@ -468,48 +421,104 @@ To use a custom metrics.properties for the application master and executors, upd - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. -# Running in a Secure Cluster +# Kerberos + +Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. + +In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home +directory, Spark will also automatically obtain delegation tokens for the service hosting the +staging directory of the Spark application. + +If an application needs to interact with other secure Hadoop filesystems, their URIs need to be +explicitly provided to Spark at launch time. This is done by listing them in the +`spark.yarn.access.hadoopFileSystems` property, described in the configuration section below. -As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to -authenticate principals associated with services and clients. This allows clients to -make requests of these authenticated services; the services to grant rights -to the authenticated principals. +The YARN integration also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. These +providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to +`false`, where `{service}` is the name of the credential provider. + +## YARN-specific Kerberos Configuration + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse - Specifies whether acls should be checked to authorize users viewing the applications. - If enabled, access control checks are made regardless of what the individual application had - set for spark.ui.acls.enable when the application was run. The application owner - will always have authorization to view their own application and any users specified via - spark.ui.view.acls and groups specified via spark.ui.view.acls.groups - when the application was run will also have authorization to view that application. - If disabled, no access control checks are made. -
    spark.history.ui.admin.aclsempty - Comma separated list of users/administrators that have view access to all the Spark applications in - history server. By default only the users permitted to view the application at run-time could - access the related application history, with this, configured users/administrators could also - have the permission to access it. - Putting a "*" in the list means any user can have the privilege of admin. -
    spark.history.ui.admin.acls.groupsempty - Comma separated list of groups that have view access to all the Spark applications in - history server. By default only the groups permitted to view the application at run-time could - access the related application history, with this, configured groups could also - have the permission to access it. - Putting a "*" in the list means any group can have the privilege of admin. -
    spark.history.fs.cleaner.enabled falsespark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to - being added to YARN's distributed cache. For use in cases where the YARN service does not + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not support schemes that are supported by Spark, like http, https and ftp.
    spark.yarn.access.hadoopFileSystems(none) - A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For - example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, - webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed - and Kerberos must be properly configured to be able to access them (either in the same realm - or in a trusted realm). Spark acquires security tokens for each of the filesystems so that - the Spark application can access those remote Hadoop filesystems. spark.yarn.access.namenodes - is deprecated, please use this instead. -
    spark.yarn.appMasterEnv.[EnvironmentVariableName] (none)
    spark.yarn.keytab(none) - The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) -
    spark.yarn.principal(none) - Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) -
    spark.yarn.kerberos.relogin.period1m - How often to check whether the kerberos TGT should be renewed. This should be set to a value - that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). - The default value should be enough for most deployments. -
    spark.yarn.config.gatewayPath (none)
    spark.security.credentials.${service}.enabledtrue - Controls whether to obtain credentials for services when security is enabled. - By default, credentials for all supported services are retrieved when those services are - configured, but it's possible to disable that behavior if it somehow conflicts with the - application being run. For further details please see - [Running in a Secure Cluster](running-on-yarn.html#running-in-a-secure-cluster) -
    spark.yarn.rolledLog.includePattern (none)
    + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. This keytab + will be copied to the node running the YARN Application Master via the YARN Distributed Cache, and + will be used for renewing the login tickets and the delegation tokens periodically. Equivalent to + the --keytab command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure clusters. Equivalent to the + --principal command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.access.hadoopFileSystems(none) + A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For + example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the filesystems so that + the Spark application can access those remote Hadoop filesystems. +
    spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
    -Hadoop services issue *hadoop tokens* to grant access to the services and data. -Clients must first acquire tokens for the services they will access and pass them along with their -application as it is launched in the YARN cluster. +## Troubleshooting Kerberos -For a Spark application to interact with any of the Hadoop filesystem (for example hdfs, webhdfs, etc), HBase and Hive, it must acquire the relevant tokens -using the Kerberos credentials of the user launching the application -—that is, the principal whose identity will become that of the launched Spark application. +Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to +enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` +environment variable. -This is normally done at launch time: in a secure cluster Spark will automatically obtain a -token for the cluster's default Hadoop filesystem, and potentially for HBase and Hive. +```bash +export HADOOP_JAAS_DEBUG=true +``` -An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares -the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.security.credentials.hbase.enabled` is not set to `false`. +The JDK classes can be configured to enable extra logging of their Kerberos and +SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` +and `sun.security.spnego.debug=true` -Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration -includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.security.credentials.hive.enabled` is not set to `false`. +``` +-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` -If an application needs to interact with other secure Hadoop filesystems, then -the tokens needed to access these clusters must be explicitly requested at -launch time. This is done by listing them in the `spark.yarn.access.hadoopFileSystems` property. +All these options can be enabled in the Application Master: ``` -spark.yarn.access.hadoopFileSystems hdfs://ireland.example.org:8020/,webhdfs://frankfurt.example.org:50070/ +spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true +spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true ``` -Spark supports integrating with other security-aware services through Java Services mechanism (see -`java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` -should be available to Spark by listing their names in the corresponding file in the jar's -`META-INF/services` directory. These plug-ins can be disabled by setting -`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of -credential provider. +Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log +will include a list of all tokens obtained, and their expiry details -## Configuring the External Shuffle Service + +# Configuring the External Shuffle Service To start the Spark Shuffle Service on each `NodeManager` in your YARN cluster, follow these instructions: @@ -542,7 +551,7 @@ The following extra configuration options are available when the shuffle service -## Launching your application with Apache Oozie +# Launching your application with Apache Oozie Apache Oozie can launch Spark applications as part of a workflow. In a secure cluster, the launched application will need the relevant tokens to access the cluster's @@ -576,35 +585,7 @@ spark.security.credentials.hbase.enabled false The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. -## Troubleshooting Kerberos - -Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to -enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` -environment variable. - -```bash -export HADOOP_JAAS_DEBUG=true -``` - -The JDK classes can be configured to enable extra logging of their Kerberos and -SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` -and `sun.security.spnego.debug=true` - -``` --Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -All these options can be enabled in the Application Master: - -``` -spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true -spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log -will include a list of all tokens obtained, and their expiry details - -## Using the Spark History Server to replace the Spark Web UI +# Using the Spark History Server to replace the Spark Web UI It is possible to use the Spark History Server application page as the tracking URL for running applications when the application UI is disabled. This may be desirable on secure clusters, or to diff --git a/docs/security.md b/docs/security.md index 913d9df50eb1c..3e5607a9a0d67 100644 --- a/docs/security.md +++ b/docs/security.md @@ -3,47 +3,336 @@ layout: global displayTitle: Spark Security title: Security --- +* This will become a table of contents (this text will be scraped). +{:toc} -Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: +# Spark RPC -* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. -* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. +## Authentication -## Web UI +Spark currently supports authentication for RPC channels using a shared secret. Authentication can +be turned on by setting the `spark.authenticate` configuration parameter. -The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). +The exact mechanism used to generate and distribute the shared secret is deployment-specific. -### Authentication +For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle +generating and distributing the shared secret. Each application will use a unique shared secret. In +the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of +secrets to be secure. -A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. +This secret will be shared by all the daemons and applications, so this deployment configuration is +not as secure as the above, especially when considering multi-tenant clusters. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. -Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.authenticatefalseWhether Spark authenticates its internal connections.
    spark.authenticate.secretNone + The secret key used authentication. See above for when this configuration should be set. +
    + +## Encryption -## Event Logging +Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC +authentication must also be enabled and properly configured. AES encryption uses the +[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +configuration system allows access to that library's configuration for advanced users. -If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +There is also support for SASL-based encryption, although it should be considered deprecated. It +is still required when talking to shuffle services from Spark versions older than 2.2.0. -## Encryption +The following table describes the different options available for configuring this feature. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.network.crypto.enabledfalse + Enable AES-based RPC encryption, including the new authentication protocol added in 2.2.0. +
    spark.network.crypto.keyLength128 + The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. +
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 + The key factory algorithm to use when generating encryption keys. Should be one of the + algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. +
    spark.network.crypto.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    spark.network.crypto.saslFallbacktrue + Whether to fall back to SASL authentication if authentication fails using Spark's internal + mechanism. This is useful when the application is connecting to old shuffle services that + do not support the internal Spark authentication protocol. On the shuffle service side, + disabling this feature will block older clients from authenticating. +
    spark.authenticate.enableSaslEncryptionfalse + Enable SASL-based encrypted communication. +
    spark.network.sasl.serverAlwaysEncryptfalse + Disable unencrypted connections for ports using SASL authentication. This will deny connections + from clients that have authentication enabled, but do not request SASL-based encryption. +
    + + +# Local Storage Encryption + +Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle +spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover +encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or +`saveAsTable`. + +The following settings cover enabling encryption for data written to disk: + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.io.encryption.enabledfalse + Enable local disk I/O encryption. Currently supported by all modes except Mesos. It's strongly + recommended that RPC encryption be enabled when using this feature. +
    spark.io.encryption.keySizeBits128 + IO encryption key size in bits. Supported values are 128, 192 and 256. +
    spark.io.encryption.keygen.algorithmHmacSHA1 + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. +
    spark.io.encryption.commons.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    + + +# Web UI + +## Authentication and Authorization + +Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +You will need a filter that implements the authentication method you want to deploy. Spark does not +provide any built-in authentication filters. + +Spark also supports access control to the UI when an authentication filter is present. Each +application can be configured with its own separate access control lists (ACLs). Spark +differentiates between "view" permissions (who is allowed to see the application's UI), and "modify" +permissions (who can do things like kill jobs in a running application). + +ACLs can be configured for either users or groups. Configuration entries accept comma-separated +lists as input, meaning multiple users or groups can be given the desired privileges. This can be +used if you run on a shared cluster and have a set of administrators or developers who need to +monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL +means that all users will have the respective pivilege. By default, only the user submitting the +application is added to the ACLs. + +Group membership is established by using a configurable group mapping provider. The mapper is +configured using the spark.user.groups.mapping config option, described in the table +below. + +The following options control the authentication of Web UIs: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.filtersNone + See the Spark UI configuration for how to configure + filters. +
    spark.acls.enablefalse + Whether UI ACLs should be enabled. If enabled, this checks to see if the user has access + permissions to view or modify the application. Note this requires the user to be authenticated, + so if no authentication filter is installed, this option does not do anything. +
    spark.admin.aclsNone + Comma-separated list of users that have view and modify access to the Spark application. +
    spark.admin.acls.groupsNone + Comma-separated list of groups that have view and modify access to the Spark application. +
    spark.modify.aclsNone + Comma-separated list of users that have modify access to the Spark application. +
    spark.modify.acls.groupsNone + Comma-separated list of groups that have modify access to the Spark application. +
    spark.ui.view.aclsNone + Comma-separated list of users that have view access to the Spark application. +
    spark.ui.view.acls.groupsNone + Comma-separated list of groups that have view access to the Spark application. +
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider, which can be configured by + this property. + +
    By default, a Unix shell-based implementation is used, which collects this information + from the host OS. + +
    Note: This implementation supports only Unix/Linux-based environments. + Windows environment is currently not supported. However, a new platform/protocol can + be supported by implementing the trait mentioned above. +
    + +On YARN, the view and modify ACLs are provided to the YARN service when submitting applications, and +control who has the respective privileges via YARN interfaces. + +## Spark History Server ACLs -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service -and the RPC endpoints. Shuffle files can also be encrypted if desired. +Authentication for the SHS Web UI is enabled the same way as for regular applications, using +servlet filters. -### SSL Configuration +To enable authorization in the SHS, a few extra options are used: + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse + Specifies whether ACLs should be checked to authorize users viewing the applications in + the history server. If enabled, access control checks are performed regardless of what the + individual applications had set for spark.ui.acls.enable. The application owner + will always have authorization to view their own application and any users specified via + spark.ui.view.acls and groups specified via spark.ui.view.acls.groups + when the application was run will also have authorization to view that application. + If disabled, no access control checks are made for any application UIs available through + the history server. +
    spark.history.ui.admin.aclsNone + Comma separated list of users that have view access to all the Spark applications in history + server. +
    spark.history.ui.admin.acls.groupsNone + Comma separated list of groups that have view access to all the Spark applications in history + server. +
    + +The SHS uses the same options to configure the group mapping provider as regular applications. +In this case, the group mapping provider will apply to all UIs server by the SHS, and individual +application configurations will be ignored. + +## SSL Configuration Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the -protocols without disabling the ability to configure each one individually. The common SSL settings -are at `spark.ssl` namespace in Spark configuration. The following table describes the -component-specific configuration namespaces used to override the default settings: +protocols without disabling the ability to configure each one individually. The following table +describes the the SSL configuration namespaces: + + + + @@ -58,49 +347,205 @@ component-specific configuration namespaces used to override the default setting
    Config Namespace Component
    spark.ssl + The default SSL configuration. These values will apply to all namespaces below, unless + explicitly overridden at the namespace level. +
    spark.ssl.ui Spark application Web UI
    -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). -SSL must be configured on each node and configured for each component involved in communication using the particular protocol. +The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be +replaced with one of the above namespaces. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    ${ns}.enabledfalseEnables SSL. When enabled, ${ns}.ssl.protocol is required.
    ${ns}.portNone + The port where the SSL service will listen on. + +
    The port must be defined within a specific namespace configuration. The default + namespace is ignored when reading this configuration. + +
    When not set, the SSL port will be derived from the non-SSL port for the + same service. A value of "0" will make the service bind to an ephemeral port. +
    ${ns}.enabledAlgorithmsNone + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + +
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section + of the Java security guide. The list for Java 8 can be found at + this + page. + +
    Note: If not set, the default cipher suite for the JRE will be used. +
    ${ns}.keyPasswordNone + The password to the private key in the key store. +
    ${ns}.keyStoreNone + Path to the key store file. The path can be absolute or relative to the directory in which the + process is started. +
    ${ns}.keyStorePasswordNonePassword to the key store.
    ${ns}.keyStoreTypeJKSThe type of the key store.
    ${ns}.protocolNone + TLS protocol to use. The protocol must be supported by JVM. + +
    The reference list of protocols can be found in the "Additional JSSE Standard Names" + section of the Java security guide. For Java 8, the list can be found at + this + page. +
    ${ns}.needClientAuthfalseWhether to require client authentication.
    ${ns}.trustStoreNone + Path to the trust store file. The path can be absolute or relative to the directory in which + the process is started. +
    ${ns}.trustStorePasswordNonePassword for the trust store.
    ${ns}.trustStoreTypeJKSThe type of the trust store.
    + +## Preparing the key stores + +Key stores can be generated by `keytool` program. The reference documentation for this tool for +Java 8 is [here](https://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html). +The most basic steps to configure the key stores and the trust store for a Spark Standalone +deployment mode is as follows: + +* Generate a key pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust store +* Distribute the trust store to the cluster nodes ### YARN mode -The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. -For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. +To provide a local trust store or key store file to drivers running in cluster mode, they can be +distributed with the application using the `--files` command line argument (or the equivalent +`spark.files` configuration). The files will be placed on the driver's working directory, so the TLS +configuration should just reference the file name with no absolute path. + +Distributing local key stores this way may require the files to be staged in HDFS (or other similar +distributed file system used by the cluster), so it's recommended that the undelying file system be +configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode -The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +The user needs to provide key stores and configuration options for master and workers. They have to +be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in +`SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. + +The user may allow the executors to use the SSL settings inherited from the worker process. That +can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that case, the settings +provided by the user on the client side are not used. ### Mesos mode -Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. Depending on the secret store backend secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value type secrets are passed on the command line and translated into their appropriate files or environment variables. +Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based +secrets. Spark allows the specification of file-based and environment variable based secrets with +`spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. -### Preparing the key-stores -Key-stores can be generated by `keytool` program. The reference documentation for this tool is -[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic -steps to configure the key-stores and the trust-store for the standalone deployment mode is as -follows: +Depending on the secret store backend secrets can be passed by reference or by value with the +`spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, +respectively. -* Generate a keys pair for each node -* Export the public key of the key pair to a file on each node -* Import all exported public keys into a single trust-store -* Distribute the trust-store over the nodes +Reference type secrets are served by the secret store and referred to by name, for example +`/mysecret`. Value type secrets are passed on the command line and translated into their +appropriate files or environment variables. -### Configuring SASL Encryption +## HTTP Security Headers -SASL encryption is currently supported for the block transfer service when authentication -(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set -`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. +Apache Spark can be configured to include HTTP headers to aid in preventing Cross Site Scripting +(XSS), Cross-Frame Scripting (XFS), MIME-Sniffing, and also to enforce HTTP Strict Transport +Security. -When using an external shuffle service, it's possible to disable unencrypted connections by setting -`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that -option is enabled, applications that are not set up to use SASL encryption will fail to connect to -the shuffle service. + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
      +
    • 0 (Disables XSS filtering)
    • +
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
    • +
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
    • +
    +
    spark.ui.xContentTypeOptions.enabledtrue + When enabled, X-Content-Type-Options HTTP response header will be set to "nosniff". +
    spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly. This option is only used when + SSL/TLS is enabled. +
      +
    • max-age=<expire-time>
    • +
    • max-age=<expire-time>; includeSubDomains
    • +
    • max-age=<expire-time>; preload
    • +
    +
    -## Configuring Ports for Network Security + +# Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight firewall settings. Below are the primary ports that Spark uses for its communication and how to configure those ports. -### Standalone mode only +## Standalone mode only @@ -141,7 +586,7 @@ configure those ports.
    -### All cluster managers +## All cluster managers @@ -182,54 +627,70 @@ configure those ports.
    -### HTTP Security Headers -Apache Spark can be configured to include HTTP Headers which aids in preventing Cross -Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP -Strict Transport Security. +# Kerberos + +Spark supports submitting applications in environments that use Kerberos for authentication. +In most cases, Spark relies on the credentials of the current logged in user when authenticating +to Kerberos-aware services. Such credentials can be obtained by logging in to the configured KDC +with tools like `kinit`. + +When talking to Hadoop-based services, Spark needs to obtain delegation tokens so that non-local +processes can authenticate. Spark ships with support for HDFS and other Hadoop file systems, Hive +and HBase. + +When using a Hadoop filesystem (such HDFS or WebHDFS), Spark will acquire the relevant tokens +for the service hosting the user's home directory. + +An HBase token will be obtained if HBase is in the application's classpath, and the HBase +configuration has Kerberos authentication turned (`hbase.security.authentication=kerberos`). + +Similarly, a Hive token will be obtained if Hive is in the classpath, and the configuration includes +URIs for remote metastore services (`hive.metastore.uris` is not empty). + +Delegation token support is currently only supported in YARN and Mesos modes. Consult the +deployment-specific page for more information. + +The following options provides finer-grained control for this feature: - - - - - - + - - - - -
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block - Value for HTTP X-XSS-Protection response header. You can choose appropriate value - from below: -
      -
    • 0 (Disables XSS filtering)
    • -
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, - the browser will sanitize the page.)
    • -
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering - of the page if an attack is detected.)
    • -
    -
    spark.ui.xContentTypeOptions.enabledspark.security.credentials.${service}.enabled true - When value is set to "true", X-Content-Type-Options HTTP response header will be set - to "nosniff". Set "false" to disable. -
    spark.ui.strictTransportSecurityNone - Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate - value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. -
      -
    • max-age=<expire-time>
    • -
    • max-age=<expire-time>; includeSubDomains
    • -
    • max-age=<expire-time>; preload
    • -
    + Controls whether to obtain credentials for services when security is enabled. + By default, credentials for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run.
    - -See the [configuration page](configuration.html) for more details on the security configuration -parameters, and -org.apache.spark.SecurityManager for implementation details about security. +## Long-Running Applications + +Long-running applications may run into issues if their run time exceeds the maximum delegation +token lifetime configured in services it needs to access. + +Spark supports automatically creating new tokens for these applications when running in YARN mode. +Kerberos credentials need to be provided to the Spark application via the `spark-submit` command, +using the `--principal` and `--keytab` parameters. + +The provided keytab will be copied over to the machine running the Application Master via the Hadoop +Distributed Cache. For this reason, it's strongly recommended that both YARN and HDFS be secured +with encryption, at least. + +The Kerberos login will be periodically renewed using the provided credentials, and new delegation +tokens for supported will be created. + + +# Event Logging + +If your applications are using event logging, the directory where the event logs go +(`spark.eventLog.dir`) should be manually created with proper permissions. To secure the log files, +the directory permissions should be set to `drwxrwxrwxt`. The owner and group of the directory +should correspond to the super user who is running the Spark History Server. +This will allow all users to write to the directory but will prevent unprivileged users from +reading, removing or renaming a file unless they own it. The event log files will be created by +Spark with permissions such that only the user and group have read and write access. From 3e778f5a91b0553b09fe0e0ee84d771a71504960 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 26 Mar 2018 15:45:27 -0700 Subject: [PATCH 217/593] [SPARK-23162][PYSPARK][ML] Add r2adj into Python API in LinearRegressionSummary ## What changes were proposed in this pull request? Adding r2adj in LinearRegressionSummary for Python API. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes in tests.py. Author: Kevin Yu Closes #20842 from kevinyu98/spark-23162. --- python/pyspark/ml/regression.py | 18 ++++++++++++++++-- python/pyspark/ml/tests.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de0a0fa9f3bf8..9a66d87d7f211 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -336,10 +336,10 @@ def rootMeanSquaredError(self): @since("2.0.0") def r2(self): """ - Returns R^2^, the coefficient of determination. + Returns R^2, the coefficient of determination. .. seealso:: `Wikipedia coefficient of determination \ - ` + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -347,6 +347,20 @@ def r2(self): """ return self._call_java("r2") + @property + @since("2.4.0") + def r2adj(self): + """ + Returns Adjusted R^2, the adjusted coefficient of determination. + + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ + `_ + + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark versions. + """ + return self._call_java("r2adj") + @property @since("2.0.0") def residuals(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index cf1ffa181ecec..6b4376cbf14e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1559,6 +1559,7 @@ def test_linear_regression_summary(self): self.assertAlmostEqual(s.meanSquaredError, 0.0) self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) self.assertTrue(isinstance(s.residuals, DataFrame)) self.assertEqual(s.numInstances, 2) self.assertEqual(s.degreesOfFreedom, 1) From 35997b59f3116830af06b3d40a7675ef0dbf7091 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Mar 2018 14:49:50 +0200 Subject: [PATCH 218/593] [SPARK-23794][SQL] Make UUID as stateful expression ## What changes were proposed in this pull request? The UUID() expression is stateful and should implement the `Stateful` trait instead of the `Nondeterministic` trait. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20912 from viirya/SPARK-23794. --- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 4 +++- .../sql/catalyst/expressions/MiscExpressionsSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ec93620038cff..a390f8ef7fd9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -123,7 +123,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { def this() = this(None) @@ -152,4 +152,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = "false") } + + override def freshCopy(): Uuid = Uuid(randomSeed) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 3383d421f5616..b6c269348b002 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -59,6 +59,12 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithGeneratedMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) + + val uuid = Uuid(seed1) + assert(uuid.fastEquals(uuid)) + assert(!uuid.fastEquals(Uuid(seed1))) + assert(!uuid.fastEquals(uuid.freshCopy())) + assert(!uuid.fastEquals(Uuid(seed2))) } test("PrintToStderr") { From c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 27 Mar 2018 14:39:05 -0700 Subject: [PATCH 219/593] [SPARK-23096][SS] Migrate rate source to V2 ## What changes were proposed in this pull request? This PR migrate micro batch rate source to V2 API and rewrite UTs to suite V2 test. ## How was this patch tested? UTs. Author: jerryshao Closes #20688 from jerryshao/SPARK-23096. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------- .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++ .../sources/RateStreamSourceV2.scala | 187 ---------- .../execution/streaming/RateSourceSuite.scala | 194 ---------- .../streaming/RateSourceV2Suite.scala | 191 ---------- .../sources/RateStreamProviderSuite.scala | 344 ++++++++++++++++++ 10 files changed, 715 insertions(+), 844 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala deleted file mode 100644 index 03d0f63fa4d7f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) - } - } - - test("basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("valueAtSecond") { - import RateStreamSource._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[ArithmeticException](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) - for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala new file mode 100644 index 0000000000000..9149e50962255 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.nio.file.Files +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + SparkSession.setActiveSession(spark) + } + + override def afterAll(): Unit = { + SparkSession.clearActiveSession() + super.afterAll() + } + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("microbatch - basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("valueAtSecond") { + import RateStreamProvider._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[TreeNodeException[_]](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getCause.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[IllegalArgumentException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + for (msg <- expectedMessages) { + assert(e.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} From ed72badb04a56d8046bbd185245abf5ae265ccfd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 27 Mar 2018 20:06:12 -0700 Subject: [PATCH 220/593] [SPARK-23699][PYTHON][SQL] Raise same type of error caught with Arrow enabled ## What changes were proposed in this pull request? When using Arrow for createDataFrame or toPandas and an error is encountered with fallback disabled, this will raise the same type of error instead of a RuntimeError. This change also allows for the traceback of the error to be retained and prevents the accidental chaining of exceptions with Python 3. ## How was this patch tested? Updated existing tests to verify error type. Author: Bryan Cutler Closes #20839 from BryanCutler/arrow-raise-same-error-SPARK-23699. --- python/pyspark/sql/dataframe.py | 25 +++++++++++++------------ python/pyspark/sql/session.py | 13 +++++++------ python/pyspark/sql/tests.py | 10 +++++----- python/pyspark/sql/utils.py | 6 ++++++ 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3fc194d8ec1d1..16f8e52dead7b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2007,7 +2007,7 @@ def toPandas(self): "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) @@ -2015,11 +2015,12 @@ def toPandas(self): else: msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Try to use Arrow optimization when the schema is supported and the required version # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. @@ -2042,12 +2043,12 @@ def toPandas(self): # be executed. So, simply fail in this case for now. msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed unexpectedly:\n %s\n" - "Note that 'spark.sql.execution.arrow.fallback.enabled' does " - "not have an effect in such failure in the middle of " - "computation." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and can not continue. Note that " + "'spark.sql.execution.arrow.fallback.enabled' does not have an effect " + "on failures in the middle of computation.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Below is toPandas without Arrow optimization. pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e82a9750a0014..13d6e2e53dbd0 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -674,18 +674,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) else: msg = ( "createDataFrame attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 967cc83166f3f..01c5dd6ff8c3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3559,7 +3559,7 @@ def test_toPandas_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): @@ -3682,7 +3682,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3707,7 +3707,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3775,14 +3775,14 @@ def test_createDataFrame_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type'): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 578298632dd4c..45363f089a73d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -121,7 +121,10 @@ def require_minimum_pandas_version(): from distutils.version import LooseVersion try: import pandas + have_pandas = True except ImportError: + have_pandas = False + if not have_pandas: raise ImportError("Pandas >= %s must be installed; however, " "it was not found." % minimum_pandas_version) if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): @@ -138,7 +141,10 @@ def require_minimum_pyarrow_version(): from distutils.version import LooseVersion try: import pyarrow + have_arrow = True except ImportError: + have_arrow = False + if not have_arrow: raise ImportError("PyArrow >= %s must be installed; however, " "it was not found." % minimum_pyarrow_version) if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): From 34c4b9c57e114cdb390e4dbc7383284d82fea317 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Mar 2018 19:49:27 +0800 Subject: [PATCH 221/593] [SPARK-23765][SQL] Supports custom line separator for json datasource ## What changes were proposed in this pull request? This PR proposes to add lineSep option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. The approach is similar with https://github.com/apache/spark/pull/20727; however, one main difference is, it uses text datasource's `lineSep` option to parse line by line in JSON's schema inference. ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Author: hyukjinkwon Closes #20877 from HyukjinKwon/linesep-json. --- python/pyspark/sql/readwriter.py | 14 ++-- python/pyspark/sql/streaming.py | 6 +- python/pyspark/sql/tests.py | 17 +++++ .../spark/sql/catalyst/json/JSONOptions.scala | 11 ++++ .../sql/catalyst/json/JacksonGenerator.scala | 8 ++- .../apache/spark/sql/DataFrameReader.scala | 2 + .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/json/JsonDataSource.scala | 17 +++-- .../datasources/text/TextOptions.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 66 ++++++++++++++++++- .../datasources/text/TextSuite.scala | 4 +- 12 files changed, 136 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4f9b9383a5ef4..6bd79bc2f43e5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -746,7 +748,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, + lineSep=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -770,12 +773,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, + lineSep=lineSep) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index c7907aaaf1f7b..15f9407389864 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,7 +405,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -468,6 +468,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -482,7 +484,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 01c5dd6ff8c3f..5181053a0d318 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -676,6 +676,23 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478a..5c9adc3332bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.json +import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -85,6 +86,16 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index eb06e4f304f0a..9c413de752a8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer +import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -74,6 +75,8 @@ private[sql] class JacksonGenerator( private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val lineSeparator: String = options.lineSeparatorInWrite + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -251,5 +254,8 @@ private[sql] class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } - def writeLineEnding(): Unit = gen.writeRaw('\n') + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1a5e47508c070..ae3ba1690f696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -366,6 +366,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bb93889dc55e9..bbc063148a72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,6 +518,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 77e7edc8e7a20..5769c09c9a1d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,7 +92,8 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + val json: Dataset[String] = createBaseDataset( + sparkSession, inputPaths, parsedOptions.lineSeparator) inferFromDataset(json, parsedOptions) } @@ -104,13 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { private def createBaseDataset( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): Dataset[String] = { + inputPaths: Seq[FileStatus], + lineSeparator: Option[String]): Dataset[String] = { + val textOptions = lineSeparator.map { lineSep => + Map(TextOptions.LINE_SEPARATOR -> lineSep) + }.getOrElse(Map.empty[String, String]) + val paths = inputPaths.map(_.getPath.toString) sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -120,7 +127,7 @@ object TextInputJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) val safeParser = new FailureSafeParser[Text]( input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 18698df9fd8e5..5c1a35434f7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -52,7 +52,7 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[text] object TextOptions { +private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" val LINE_SEPARATOR = "lineSep" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9b17406a816b5..ae93965bc50ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -268,6 +268,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115a..10bac0554484a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -27,7 +28,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} @@ -2063,4 +2064,67 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val data = + s""" + | {"f": + |"a", "f0": 1}$lineSep{"f": + | + |"c", "f0": 2}$lineSep{"f": "d", "f0": 3} + """.stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write.option("lineSep", lineSep).json(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert( + readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write.option("lineSep", lineSep).json(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => + testLineSeparator(lineSep) + } + // scalastyle:on nonascii + + test("""SPARK-21289: Support line separator - default value \r, \r\n and \n""") { + val data = + "{\"f\": \"a\", \"f0\": 1}\r{\"f\": \"c\", \"f0\": 2}\r\n{\"f\": \"d\", \"f0\": 3}\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index e8a5299d6ba9d..0e7f3afa9c3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -208,9 +208,11 @@ class TextSuite extends QueryTest with SharedSQLContext { } } - Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => testLineSeparator(lineSep) } + // scalastyle:on nonascii private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString From 761565a3ccbf7f083e587fee14a27b61867a3886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 28 Mar 2018 09:11:52 -0700 Subject: [PATCH 222/593] Revert "[SPARK-23096][SS] Migrate rate source to V2" This reverts commit c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 +++++++++++++ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 ----------- .../sources/RateStreamProvider.scala | 125 ------- .../sources/RateStreamSourceV2.scala | 187 ++++++++++ .../execution/streaming/RateSourceSuite.scala | 194 ++++++++++ .../streaming/RateSourceV2Suite.scala | 191 ++++++++++ .../sources/RateStreamProviderSuite.scala | 344 ------------------ 10 files changed, 844 insertions(+), 715 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1b37905543b4e..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,5 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,7 +566,6 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName - val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -588,8 +587,7 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, - "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..649fbbfa184ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister + with DataSourceV2 with ContinuousReadSupport { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + if (schema.nonEmpty) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + (shortName(), RateSourceProvider.SCHEMA) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + new RateStreamContinuousReader(options) + } + + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA + override def readSchema(): StructType = RateSourceProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,19 +98,6 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} - private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } - } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala deleted file mode 100644 index 6cf8520fc544f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ManualClock, SystemClock} - -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { - import RateStreamProvider._ - - private[sources] val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock - } - - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private[sources] val creationTimeMs = { - val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) - require(session.isDefined) - - val metadataLog = - new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - @volatile private var lastTimeMs: Long = creationTimeMs - - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA - - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return List.empty.asJava - } - - val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } - - (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( - p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" -} - -class RateStreamMicroBatchDataReaderFactory( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { - - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) -} - -class RateStreamMicroBatchDataReader( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { - private var count = 0 - - override def next(): Boolean = { - rangeStart + partitionId + numPartitions * count < rangeEnd - } - - override def get(): Row = { - val currValue = rangeStart + partitionId + numPartitions * count - count += 1 - val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative + localStartTimeMs)), - currValue - ) - } - - override def close(): Unit = {} -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala deleted file mode 100644 index 6bdd492f0cb35..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} -import org.apache.spark.sql.types._ - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { - import RateStreamProvider._ - - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - if (options.get(ROWS_PER_SECOND).isPresent) { - val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") - } - } - - if (options.get(RAMP_UP_TIME).isPresent) { - val rampUpTimeSeconds = - JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") - } - } - - if (options.get(NUM_PARTITIONS).isPresent) { - val numPartitions = options.get(NUM_PARTITIONS).get().toInt - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") - } - } - - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) - - override def shortName(): String = "rate" -} - -object RateStreamProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 - - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - val RAMP_UP_TIME = "rampUpTime" - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala new file mode 100644 index 0000000000000..4e2459bb05bd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceOptions) + extends MicroBatchReader { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } + + private val numPartitions = + options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + private val rowsPerSecond = + options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + + // The interval (in milliseconds) between rows in each partition. + // e.g. if there are 4 global rows per second, and 2 partitions, each partition + // should output rows every (1000 * 2 / 4) = 500 ms. + private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond + + override def readSchema(): StructType = { + StructType( + StructField("timestamp", TimestampType, false) :: + StructField("value", LongType, false) :: Nil) + } + + val creationTimeMs = clock.getTimeMillis() + + private var start: RateStreamOffset = _ + private var end: RateStreamOffset = _ + + override def setOffsetRange( + start: Optional[Offset], + end: Optional[Offset]): Unit = { + this.start = start.orElse( + RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) + .asInstanceOf[RateStreamOffset] + + this.end = end.orElse { + val currentTime = clock.getTimeMillis() + RateStreamOffset( + this.start.partitionToValueAndRunTimeMs.map { + case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + // Calculate the number of rows we should advance in this partition (based on the + // current time), and output a corresponding offset. + val readInterval = currentTime - currentReadTime + val numNewRows = readInterval / msPerPartitionBetweenRows + if (numNewRows <= 0) { + startOffset + } else { + (part, ValueRunTimeMsPair( + currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + } + } + ) + }.asInstanceOf[RateStreamOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startMap = start.partitionToValueAndRunTimeMs + val endMap = end.partitionToValueAndRunTimeMs + endMap.keys.toSeq.map { part => + val ValueRunTimeMsPair(endVal, _) = endMap(part) + val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) + + val packedRows = mutable.ListBuffer[(Long, Long)]() + var outVal = startVal + numPartitions + var outTimeMs = startTimeMs + while (outVal <= endVal) { + packedRows.append((outTimeMs, outVal)) + outVal += numPartitions + outTimeMs += msPerPartitionBetweenRows + } + + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} +} + +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) +} + +class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the seq. + currentIndex += 1 + currentIndex < vals.size + } + + override def get(): Row = { + Row( + DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), + vals(currentIndex)._2) + } + + override def close(): Unit = {} +} + +object RateStreamSourceV2 { + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + + private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..03d0f63fa4d7f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala new file mode 100644 index 0000000000000..983ba1668f58f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - numPartitions propagated") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + } + + test("microbatch - set offset") { + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: RateStreamOffset => + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: RateStreamOffset => + // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted + // longer than 100ms. It should never be early. + assert(r.partitionToValueAndRunTimeMs(0).value >= 9) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) + + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) + } + + test("microbatch - data read") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) + val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { + case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) + }.toMap) + + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala deleted file mode 100644 index 9149e50962255..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.nio.file.Files -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.Offset -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - protected override def beforeAll(): Unit = { - super.beforeAll() - SparkSession.setActiveSession(spark) - } - - override def afterAll(): Unit = { - SparkSession.clearActiveSession() - super.afterAll() - } - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( - rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) - (rateSource, offset) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("compatible with old path in registry") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - assert(ds.isInstanceOf[RateStreamProvider]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("microbatch - basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) - } - assert(data.size === 20) - } - - test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("valueAtSecond") { - import RateStreamProvider._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[TreeNodeException[_]](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getCause.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[IllegalArgumentException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - for (msg <- expectedMessages) { - assert(e.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find read support for continuous rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} From ea2fdc0d286e449884de44f22a908a26ab1248a5 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Wed, 28 Mar 2018 19:49:32 -0500 Subject: [PATCH 223/593] [SPARK-23675][WEB-UI] Title add spark logo, use spark logo image ## What changes were proposed in this pull request? Title add spark logo, use spark logo image. reference other big data system ui, so i think spark should add it. spark fix before: ![spark_fix_before](https://user-images.githubusercontent.com/26266482/37387866-2d5add0e-2799-11e8-9165-250f2b59df3f.png) spark fix after: ![spark_fix_after](https://user-images.githubusercontent.com/26266482/37387874-329e1876-2799-11e8-8bc5-c619fc1e680e.png) reference kafka ui: ![kafka](https://user-images.githubusercontent.com/26266482/37387878-35ca89d0-2799-11e8-834e-1598ae7158e1.png) reference storm ui: ![storm](https://user-images.githubusercontent.com/26266482/37387880-3854f12c-2799-11e8-8968-b428ba361995.png) reference yarn ui: ![yarn](https://user-images.githubusercontent.com/26266482/37387881-3a72e130-2799-11e8-97bb-dea85f573e95.png) reference nifi ui: ![nifi](https://user-images.githubusercontent.com/26266482/37387887-3cecfea0-2799-11e8-9a71-6c454d25840b.png) reference flink ui: ![flink](https://user-images.githubusercontent.com/26266482/37387888-3f16b1ee-2799-11e8-9d37-8355f0100548.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20818 from guoxiaolongzte/SPARK-23675. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ba798df13c95d..02cf19e00ecde 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -224,6 +224,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (showVisualization) vizHeaderNodes else Seq.empty} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {appName} - {title} @@ -265,6 +266,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {title} From 641aec68e8167546dbb922874c086c9b90198f08 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 29 Mar 2018 16:37:46 +0800 Subject: [PATCH 224/593] =?UTF-8?q?[SPARK-23806]=20Broadcast.unpersist=20c?= =?UTF-8?q?an=20cause=20fatal=20exception=20when=20used=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with dynamic allocation ## What changes were proposed in this pull request? ignore errors when you are waiting for a broadcast.unpersist. This is handling it the same way as doing rdd.unpersist in https://issues.apache.org/jira/browse/SPARK-22618 ## How was this patch tested? Patch was tested manually against a couple jobs that exhibit this behavior, with the change the application no longer dies due to this and just prints the warning. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Thomas Graves Closes #20924 from tgravescs/SPARK-23806. --- .../spark/storage/BlockManagerMasterEndpoint.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 89a6a71a589a1..56b95c31eb4c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -192,11 +192,15 @@ class BlockManagerMasterEndpoint( val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + val futures = requiredBlockManagers.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove broadcast $broadcastId", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeBlockManager(blockManagerId: BlockManagerId) { From 505480cb578af9f23acc77bc82348afc9d8468e8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 29 Mar 2018 19:38:28 +0900 Subject: [PATCH 225/593] [SPARK-23770][R] Exposes repartitionByRange in SparkR ## What changes were proposed in this pull request? This PR proposes to expose `repartitionByRange`. ```R > df <- createDataFrame(iris) ... > getNumPartitions(repartitionByRange(df, 3, col = df$Species)) [1] 3 ``` ## How was this patch tested? Manually tested and the unit tests were added. The diff with `repartition` can be checked as below: ```R > df <- createDataFrame(mtcars) > take(repartition(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 14.3 8 360.0 245 3.21 3.570 15.84 0 0 3 4 2 10.4 8 460.0 215 3.00 5.424 17.82 0 0 3 4 3 32.4 4 78.7 66 4.08 2.200 19.47 1 1 4 1 > take(repartitionByRange(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 30.4 4 75.7 52 4.93 1.615 18.52 1 1 4 2 2 33.9 4 71.1 65 4.22 1.835 19.90 1 1 4 1 3 27.3 4 79.0 66 4.08 1.935 18.90 1 1 4 1 ``` Author: hyukjinkwon Closes #20902 from HyukjinKwon/r-repartitionByRange. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 65 ++++++++++++++++++++++++++- R/pkg/R/generics.R | 3 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 45 +++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c51eb0f39c4b1..190c50ea10482 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -151,6 +151,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "repartitionByRange", "rollup", "sample", "sample_frac", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c4852024c0f49..a1c9495b0795e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -687,7 +687,7 @@ setMethod("storageLevel", #' @rdname coalesce #' @name coalesce #' @aliases coalesce,SparkDataFrame-method -#' @seealso \link{repartition} +#' @seealso \link{repartition}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -723,7 +723,7 @@ setMethod("coalesce", #' @rdname repartition #' @name repartition #' @aliases repartition,SparkDataFrame-method -#' @seealso \link{coalesce} +#' @seealso \link{coalesce}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -759,6 +759,67 @@ setMethod("repartition", dataFrame(sdf) }) + +#' Repartition by range +#' +#' The following options for repartition by range are possible: +#' \itemize{ +#' \item{1.} {Return a new SparkDataFrame range partitioned by +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame range partitioned by the given column(s), +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} +#'} +#' +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the range partitioning will be performed. +#' @param ... additional column(s) to be used in the range partitioning. +#' +#' @family SparkDataFrame functions +#' @rdname repartitionByRange +#' @name repartitionByRange +#' @aliases repartitionByRange,SparkDataFrame-method +#' @seealso \link{repartition}, \link{coalesce} +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' newDF <- repartitionByRange(df, col = df$col1, df$col2) +#' newDF <- repartitionByRange(df, 3L, col = df$col1, df$col2) +#'} +#' @note repartitionByRange since 2.4.0 +setMethod("repartitionByRange", + signature(x = "SparkDataFrame"), + function(x, numPartitions = NULL, col = NULL, ...) { + if (!is.null(numPartitions) && !is.null(col)) { + # number of partitions and columns both are specified + if (is.numeric(numPartitions) && class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", numToInt(numPartitions), jcol) + } else { + stop(paste("numPartitions and col must be numeric and Column; however, got", + class(numPartitions), "and", class(col))) + } + } else if (!is.null(col)) { + # only columns are specified + if (class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", jcol) + } else { + stop(paste("col must be Column; however, got", class(col))) + } + } else if (!is.null(numPartitions)) { + # only numPartitions is specified + stop("At least one partition-by column must be specified.") + } else { + stop("Please, specify a column(s) or the number of partitions with a column(s)") + } + dataFrame(sdf) + }) + #' toJSON #' #' Converts a SparkDataFrame into a SparkDataFrame of JSON string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6fba4b6c761dd..974beff1a3d76 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -531,6 +531,9 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) +#' @rdname repartitionByRange +setGeneric("repartitionByRange", function(x, ...) { standardGeneric("repartitionByRange") }) + #' @rdname sample setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 439191adb23ea..7105469ffc242 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3104,6 +3104,51 @@ test_that("repartition by columns on DataFrame", { }) }) +test_that("repartitionByRange on a DataFrame", { + # The tasks here launch R workers with shuffles. So, we decrease the number of shuffle + # partitions to reduce the number of the tasks to speed up the test. This is particularly + # slow on Windows because the R workers are unable to be forked. See also SPARK-21693. + conf <- callJMethod(sparkSession, "conf") + shufflepartitionsvalue <- callJMethod(conf, "get", "spark.sql.shuffle.partitions") + callJMethod(conf, "set", "spark.sql.shuffle.partitions", "5") + tryCatch({ + df <- createDataFrame(mtcars) + expect_error(repartitionByRange(df, "haha", df$mpg), + "numPartitions and col must be numeric and Column.*") + expect_error(repartitionByRange(df), + ".*specify a column.*or the number of partitions with a column.*") + expect_error(repartitionByRange(df, col = "haha"), + "col must be Column; however, got.*") + expect_error(repartitionByRange(df, 3), + "At least one partition-by column must be specified.") + + # The order of rows should be different with a normal repartition. + actual <- repartitionByRange(df, 3, df$mpg) + expect_equal(getNumPartitions(actual), 3) + expect_false(identical(collect(actual), collect(repartition(df, 3, df$mpg)))) + + actual <- repartitionByRange(df, col = df$mpg) + expect_false(identical(collect(actual), collect(repartition(df, col = df$mpg)))) + + # They should have same data. + actual <- collect(repartitionByRange(df, 3, df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, 3, df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + + actual <- collect(repartitionByRange(df, col = df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, col = df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.shuffle.partitions", shufflepartitionsvalue) + }) +}) + test_that("coalesce, repartition, numPartitions", { df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) From 491ec114fd3886ebd9fa29a482e3d112fb5a088c Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 29 Mar 2018 10:23:23 -0700 Subject: [PATCH 226/593] [SPARK-23785][LAUNCHER] LauncherBackend doesn't check state of connection before setting state ## What changes were proposed in this pull request? Changed `LauncherBackend` `set` method so that it checks if the connection is open or not before writing to it (uses `isConnected`). ## How was this patch tested? None Author: Sahil Takiar Closes #20893 from sahilTakiar/master. --- .../spark/launcher/LauncherBackend.scala | 6 +++--- .../spark/launcher/LauncherServerSuite.java | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index aaae33ca4e6f3..1b049b786023a 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -67,13 +67,13 @@ private[spark] abstract class LauncherBackend { } def setAppId(appId: String): Unit = { - if (connection != null) { + if (connection != null && isConnected) { connection.send(new SetAppId(appId)) } } def setState(state: SparkAppHandle.State): Unit = { - if (connection != null && lastState != state) { + if (connection != null && isConnected && lastState != state) { connection.send(new SetState(state)) lastState = state } @@ -114,10 +114,10 @@ private[spark] abstract class LauncherBackend { override def close(): Unit = { try { + _isConnected = false super.close() } finally { onDisconnected() - _isConnected = false } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index d16337a319be3..5413d3a416545 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -185,6 +185,26 @@ public void testStreamFiltering() throws Exception { } } + @Test + public void testAppHandleDisconnect() throws Exception { + LauncherServer server = LauncherServer.getOrCreateServer(); + ChildProcAppHandle handle = new ChildProcAppHandle(server); + String secret = server.registerHandle(handle); + + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); + client = new TestClient(s); + client.send(new Hello(secret, "1.4.0")); + handle.disconnect(); + waitForError(client, secret); + } finally { + handle.kill(); + close(client); + client.clientThread.join(); + } + } + private void close(Closeable c) { if (c != null) { try { From a7755fd8ce2f022118b9827aaac7d5d59f0f297a Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 29 Mar 2018 10:46:28 -0700 Subject: [PATCH 227/593] [SPARK-23639][SQL] Obtain token before init metastore client in SparkSQL CLI ## What changes were proposed in this pull request? In SparkSQLCLI, SessionState generates before SparkContext instantiating. When we use --proxy-user to impersonate, it's unable to initializing a metastore client to talk to the secured metastore for no kerberos ticket. This PR use real user ugi to obtain token for owner before talking to kerberized metastore. ## How was this patch tested? Manually verified with kerberized hive metasotre / hdfs. Author: Kent Yao Closes #20784 from yaooqinn/SPARK-23639. --- .../deploy/security/HiveDelegationTokenProvider.scala | 8 ++++---- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index ece5ce79c650d..7249eb85ac7c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils -private[security] class HiveDelegationTokenProvider +private[spark] class HiveDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" @@ -124,9 +124,9 @@ private[security] class HiveDelegationTokenProvider val currentUser = UserGroupInformation.getCurrentUser() val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) - // For some reason the Scala-generated anonymous class ends up causing an - // UndeclaredThrowableException, even if you annotate the method with @throws. - try { + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. + try { realUser.doAs(new PrivilegedExceptionAction[T]() { override def run(): T = fn }) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 832a15d09599f..084f8200102ba 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,13 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HiveDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -121,6 +123,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { } } + val tokenProvider = new HiveDelegationTokenProvider() + if (tokenProvider.delegationTokensRequired(sparkConf, hadoopConf)) { + val credentials = new Credentials() + tokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + UserGroupInformation.getCurrentUser.addCredentials(credentials) + } + SessionState.start(sessionState) // Clean up after we exit From b348901192b231153b58fe5720253168c87963d4 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 29 Mar 2018 21:36:56 -0700 Subject: [PATCH 228/593] [SPARK-23808][SQL] Set default Spark session in test-only spark sessions. ## What changes were proposed in this pull request? Set default Spark session in the TestSparkSession and TestHiveSparkSession constructors. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20926 from jose-torres/test3. --- .../spark/sql/test/TestSQLContext.scala | 2 ++ .../sql/test/TestSparkSessionSuite.scala | 29 +++++++++++++++++++ .../apache/spark/sql/hive/test/TestHive.scala | 4 +++ 3 files changed, 35 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 4286e8a6ca2c8..3038b822beb4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -34,6 +34,8 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) this(new SparkConf) } + SparkSession.setDefaultSession(this) + @transient override lazy val sessionState: SessionState = { new TestSQLSessionStateBuilder(this, None).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala new file mode 100644 index 0000000000000..4019c6888da98 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession + +class TestSparkSessionSuite extends SparkFunSuite { + test("default session is set in constructor") { + val session = new TestSparkSession() + assert(SparkSession.getDefaultSession.contains(session)) + session.stop() + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index fcf2025d34432..814038d4ef7af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,6 +159,10 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => + // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as + // TestSparkSession, but doing this the same way breaks many tests in the package. We need + // to investigate and find a different strategy. + def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From df05fb63abe6018ccbe572c34cf65fc3ecbf1166 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 30 Mar 2018 14:07:35 +0800 Subject: [PATCH 229/593] [SPARK-23743][SQL] Changed a comparison logic from containing 'slf4j' to starting with 'org.slf4j' ## What changes were proposed in this pull request? isSharedClass returns if some classes can/should be shared or not. It checks if the classes names have some keywords or start with some names. Following the logic, it can occur unintended behaviors when a custom package has `slf4j` inside the package or class name. As I guess, the first intention seems to figure out the class containing `org.slf4j`. It would be better to change the comparison logic to `name.startsWith("org.slf4j")` ## How was this patch tested? This patch should pass all of the current tests and keep all of the current behaviors. In my case, I'm using ProtobufDeserializer to get a table schema from hive tables. Thus some Protobuf packages and names have `slf4j` inside. Without this patch, it cannot be resolved because of ClassCastException from different classloaders. Author: Jongyoul Lee Closes #20860 from jongyoul/SPARK-23743. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 12975bc85b971..c2690ec32b9e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -179,8 +179,9 @@ private[hive] class IsolatedClientLoader( val isHadoopClass = name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") - name.contains("slf4j") || - name.contains("log4j") || + name.startsWith("org.slf4j") || + name.startsWith("org.apache.log4j") || // log4j1.x + name.startsWith("org.apache.logging.log4j") || // log4j2 name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || From b02e76cbffe9e589b7a4e60f91250ca12a4420b2 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 30 Mar 2018 15:07:38 +0800 Subject: [PATCH 230/593] [SPARK-23727][SQL] Support for pushing down filters for DateType in parquet ## What changes were proposed in this pull request? This PR supports for pushing down filters for DateType in parquet ## How was this patch tested? Added UT and tested in local. Author: yucai Closes #20851 from yucai/SPARK-23727. --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../datasources/parquet/ParquetFilters.scala | 33 ++++++++++++ .../parquet/ParquetFilterSuite.scala | 50 +++++++++++++++++-- 3 files changed, 89 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9cb03b5bb6152..13f31a6b2eb93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -353,6 +353,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") + .doc("If true, enables Parquet filter push-down optimization for Date. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1329,6 +1336,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 763841efbd9f3..ccc8306866d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.sql.Date + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ @@ -29,6 +34,10 @@ import org.apache.spark.sql.types._ */ private[parquet] object ParquetFilters { + private def dateToDays(date: Date): SQLDate = { + DateTimeUtils.fromJavaDate(date) + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -50,6 +59,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -72,6 +85,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -91,6 +108,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.lt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -110,6 +131,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.ltEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -129,6 +154,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -148,6 +177,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gtEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 33801954ebd51..1d3476e747046 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets +import java.sql.Date import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -76,8 +77,10 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -102,7 +105,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex maybeFilter.exists(_.getClass === filterClass) } checker(stripSparkFilter(query), expected) - } } } @@ -313,6 +315,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - date") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") + + withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate( + Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate( + Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) + checkFilterPredicate( + '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From 5b5a36ed6d2bb0971edfeccddf0f280936d2275f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 30 Mar 2018 21:54:26 +0800 Subject: [PATCH 231/593] Roll forward "[SPARK-23096][SS] Migrate rate source to V2" ## What changes were proposed in this pull request? Roll forward c68ec4e (#20688). There are two minor test changes required: * An error which used to be TreeNodeException[ArithmeticException] is no longer wrapped and is now just ArithmeticException. * The test framework simply does not set the active Spark session. (Or rather, it doesn't do so early enough - I think it only happens when a query is analyzed.) I've added the required logic to SQLTestUtils. ## How was this patch tested? existing tests Author: Jose Torres Author: jerryshao Closes #20922 from jose-torres/ratefix. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------------ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++++ .../sources/RateStreamSourceV2.scala | 187 ------------- .../streaming/RateSourceV2Suite.scala | 191 ------------- .../RateStreamProviderSuite.scala} | 166 ++++++++++- 9 files changed, 524 insertions(+), 663 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{RateSourceSuite.scala => sources/RateStreamProviderSuite.scala} (50%) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala similarity index 50% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 03d0f63fa4d7f..ff14ec38e66a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -15,13 +15,24 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources +import java.nio.file.Files +import java.util.Optional import java.util.concurrent.TimeUnit -import org.apache.spark.sql.AnalysisException +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock class RateSourceSuite extends StreamTest { @@ -29,18 +40,40 @@ class RateSourceSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") } } - test("basic") { + test("microbatch - basic") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "10") @@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest { ) } - test("uniform distribution of event timestamps") { + test("microbatch - uniform distribution of event timestamps") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "1500") @@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + test("valueAtSecond") { - import RateStreamSource._ + import RateStreamProvider._ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) @@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest { option: String, value: String, expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { + val e = intercept[IllegalArgumentException] { spark.readStream .format("rate") .option(option, value) @@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest { .start() .awaitTermination() } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } } @@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest { assert(exception.getMessage.contains( "rate source does not support a user-specified schema")) } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } } From bc8d0931170cfa20a4fb64b3b11a2027ddb0d6e9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 30 Mar 2018 23:21:07 +0800 Subject: [PATCH 232/593] [SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? This PR is to improve the test coverage of the original PR https://github.com/apache/spark/pull/20687 ## How was this patch tested? N/A Author: gatorsmile Closes #20911 from gatorsmile/addTests. --- .../optimizer/complexTypesSuite.scala | 176 ++++++++++++------ .../apache/spark/sql/ComplexTypesSuite.scala | 109 +++++++++++ 2 files changed, 233 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e44a6692ad8e2..21ed987627b3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { SimplifyExtractValueOps) :: Nil } - val idAtt = ('id).long.notNull - val nullableIdAtt = ('nullable_id).long + private val idAtt = ('id).long.notNull + private val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt, nullableIdAtt) + private val relation = LocalRelation(idAtt, nullableIdAtt) + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + + private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { + val optimized = Optimizer.execute(originalQuery.analyze) + assert(optimized.resolved, "optimized plans must be still resolvable") + comparePlans(optimized, correctAnswer.analyze) + } test("explicit get from namedStruct") { val query = relation @@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField( CreateNamedStruct(Seq("att", 'id )), 0, - None) as "outerAtt").analyze - val expected = relation.select('id as "outerAtt").analyze + None) as "outerAtt") + val expected = relation.select('id as "outerAtt") - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) - .analyze val expected = relation .select('id as "named_struct(att, id).att") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") .select(GetStructField('struct1, 0, None) as "struct1Att") - .analyze - val expected = relation.select('id as "struct1Att").analyze - comparePlans(Optimizer execute query, expected) + val expected = relation.select('id as "struct1Att") + checkRule(query, expected) } test("collapse multiple CreateNamedStruct/GetStructField pairs") { @@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None) as "struct1Att1", GetStructField('struct1, 1, None) as "struct1Att2") - .analyze val expected = relation. select( 'id as "struct1Att1", ('id * 'id) as "struct1Att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed2 - deduced names") { @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None), GetStructField('struct1, 1, None)) - .analyze val expected = relation. select( 'id as "struct1.att1", ('id * 'id) as "struct1.att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplified array ops") { @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { 1, false), 1) as "a4") - .analyze val expected = relation .select( @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { "att2", (('id + 1L) * ('id + 1L)))) as "a2", ('id + 1L) as "a3", ('id + 1L) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("SPARK-22570: CreateArray should not create a lot of global variables") { @@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", GetMapValue('m, "r32") as "a3", GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") - .analyze val expected = relation.select( @@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ) ) as "a3", Literal.create(null, LongType) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, constant lookup, dynamic keys") { @@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo(13L, ('id + 1L)), ('id + 2L)), (EqualTo(13L, ('id + 2L)), ('id + 3L)), (Literal(true), 'id))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { @@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), ('id + 3L)) as "a") - .analyze val expected = relation .select( CaseWhen(Seq( @@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), (Literal(true), ('id + 4L)))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, no positive match") { @@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 'id + 30L) as "a") - .analyze val expected = relation.select( CaseWhen(Seq( (EqualTo('id + 30L, 'id), ('id + 1L)), @@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, constant lookup, mixed keys, eliminated constants") { @@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 2L), ('id + 3L), ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, potential dynamic match with null value + an absolute constant match") { @@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 2L ) as "a") - .analyze val expected = relation .select( @@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // but it cannot override a potential match with ('id + 2L), // which is exactly what [[Coalesce]] would do in this case. (Literal.TrueLiteral, 'id))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) + } + + test("SPARK-23500: Simplify array ops that are not at the top node") { + val query = LocalRelation('id.long) + .select( + CreateArray(Seq( + CreateNamedStruct(Seq( + "att1", 'id, + "att2", 'id * 'id)), + CreateNamedStruct(Seq( + "att1", 'id + 1, + "att2", ('id + 1) * ('id + 1)) + )) + ) as "arr") + .select( + GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetArrayItem( + GetArrayStructFields('arr, + StructField("att1", LongType, nullable = false), + ordinal = 0, + numFields = 1, + containsNull = false), + ordinal = 1) as "a2") + .orderBy('id.asc) + + val expected = LocalRelation('id.long) + .select( + ('id + 1L) as "a1", + ('id + 1L) as "a2") + .orderBy('id.asc) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify map ops that are not top nodes") { + val query = + LocalRelation('id.long) + .select( + CreateMap(Seq( + "r1", 'id, + "r2", 'id + 1L)) as "m") + .select( + GetMapValue('m, "r1") as "a1", + GetMapValue('m, "r32") as "a2") + .orderBy('id.asc) + .select('a1, 'a2) + + val expected = + LocalRelation('id.long).select( + 'id as "a1", + Literal.create(null, LongType) as "a2") + .orderBy('id.asc) + checkRule(query, expected) } test("SPARK-23500: Simplify complex ops that aren't at the plan root") { val structRel = relation .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") - .groupBy($"foo")("1").analyze + .groupBy($"foo")("1") val structExpected = relation .select('nullable_id as "foo") - .groupBy($"foo")("1").analyze - comparePlans(Optimizer execute structRel, structExpected) + .groupBy($"foo")("1") + checkRule(structRel, structExpected) // These tests must use nullable attributes from the base relation for the following reason: // in the 'original' plans below, the Aggregate node produced by groupBy() has a @@ -351,17 +389,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") - .groupBy($"a1")("1").analyze - val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze - comparePlans(Optimizer execute arrayRel, arrayExpected) + .groupBy($"a1")("1") + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + checkRule(arrayRel, arrayExpected) val mapRel = relation .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") - .groupBy($"m1")("1").analyze + .groupBy($"m1")("1") val mapExpected = relation .select('nullable_id as "m1") - .groupBy($"m1")("1").analyze - comparePlans(Optimizer execute mapRel, mapExpected) + .groupBy($"m1")("1") + checkRule(mapRel, mapExpected) } test("SPARK-23500: Ensure that aggregation expressions are not simplified") { @@ -369,11 +407,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // grouping exprs so aren't tested here. val structAggRel = relation.groupBy( CreateNamedStruct(Seq("att1", 'nullable_id)))( - GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze - comparePlans(Optimizer execute structAggRel, structAggRel) + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + checkRule(structAggRel, structAggRel) val arrayAggRel = relation.groupBy( - CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze - comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + checkRule(arrayAggRel, arrayAggRel) + + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. + val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) + } + + test("SPARK-23500: namedStruct and getField in the same Project #1") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) + .select('s1 getField "col2" as 's1Col2, + namedStruct("col1", 'a, "col2", 'b).as("s2")) + .select('s1Col2, 's2 getField "col2" as 's2Col2) + val correctAnswer = + testRelation + .select('c as 's1Col2, 'b as 's2Col2) + checkRule(originalQuery, correctAnswer) + } + + test("SPARK-23500: namedStruct and getField in the same Project #2") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, + namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + val correctAnswer = + testRelation + .select('c as 'sCol2, 'a as 'sCol1) + checkRule(originalQuery, correctAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala new file mode 100644 index 0000000000000..b74fe2f90df23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSQLContext + +class ComplexTypesSuite extends QueryTest with SharedSQLContext { + + override def beforeAll() { + super.beforeAll() + spark.range(10).selectExpr( + "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5") + .write.saveAsTable("tab") + } + + override def afterAll() { + try { + spark.sql("DROP TABLE IF EXISTS tab") + } finally { + super.afterAll() + } + } + + def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = { + var count = 0 + plan.foreach { operator => + operator.transformExpressions { + case c: CreateNamedStruct => + count += 1 + c + } + } + + if (expectedCount != count) { + fail(s"expect $expectedCount CreateNamedStruct but got $count.") + } + } + + test("simple case") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2") + .filter("col2.c > 11").selectExpr("col1.a") + checkAnswer(df, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("named_struct is used in the top Project") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .selectExpr("col1.a", "col1") + .filter("col1.a > 8") + checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1) + + val df1 = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .sort("col1") + .selectExpr("col1.a") + .filter("col1.a > 8") + checkAnswer(df1, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1) + } + + test("expression in named_struct") { + val df = spark.table("tab") + .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola") + .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10") + checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola") + .selectExpr("cola.i3").filter("cola.exp > 10") + checkAnswer(df1, Row(12) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("nested case") { + val df = spark.table("tab") + .selectExpr("struct(struct(i2, i3) as exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10") + checkAnswer(df, Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("struct(i2, i3) as exp", "i4") + .selectExpr("struct(exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11") + checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } +} From ae9172017c361e5c1039bc2ca94048117021974a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 30 Mar 2018 14:09:14 -0700 Subject: [PATCH 233/593] [SPARK-23640][CORE] Fix hadoop config may override spark config ## What changes were proposed in this pull request? It may be get `spark.shuffle.service.port` from https://github.com/apache/spark/blob/9745ec3a61c99be59ef6a9d5eebd445e8af65b7a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala#L459 Therefore, the client configuration `spark.shuffle.service.port` does not working unless the configuration is `spark.hadoop.spark.shuffle.service.port`. - This configuration is not working: ``` bin/spark-sql --master yarn --conf spark.shuffle.service.port=7338 ``` - This configuration works: ``` bin/spark-sql --master yarn --conf spark.hadoop.spark.shuffle.service.port=7338 ``` This PR fix this issue. ## How was this patch tested? It's difficult to carry out unit testing. But I've tested it manually. Author: Yuming Wang Closes #20785 from wangyum/SPARK-23640. --- .../scala/org/apache/spark/util/Utils.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5caedeb526469..d2be93226e2a2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2302,16 +2302,20 @@ private[spark] object Utils extends Logging { } /** - * Return the value of a config either through the SparkConf or the Hadoop configuration - * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf - * if the key is not set in the Hadoop configuration. + * Return the value of a config either through the SparkConf or the Hadoop configuration. + * We Check whether the key is set in the SparkConf before look at any Hadoop configuration. + * If the key is set in SparkConf, no matter whether it is running on YARN or not, + * gets the value from SparkConf. + * Only when the key is not set in SparkConf and running on YARN, + * gets the value from Hadoop configuration. */ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { - val sparkValue = conf.get(key, default) - if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { - new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue) + if (conf.contains(key)) { + conf.get(key, default) + } else if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { + new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, default) } else { - sparkValue + default } } From 15298b99ac8944e781328423289586176cf824d7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 30 Mar 2018 16:48:26 -0700 Subject: [PATCH 234/593] [SPARK-23827][SS] StreamingJoinExec should ensure that input data is partitioned into specific number of partitions ## What changes were proposed in this pull request? Currently, the requiredChildDistribution does not specify the partitions. This can cause the weird corner cases where the child's distribution is `SinglePartition` which satisfies the required distribution of `ClusterDistribution(no-num-partition-requirement)`, thus eliminating the shuffle needed to repartition input data into the required number of partitions (i.e. same as state stores). That can lead to "file not found" errors on the state store delta files as the micro-batch-with-no-shuffle will not run certain tasks and therefore not generate the expected state store delta files. This PR adds the required constraint on the number of partitions. ## How was this patch tested? Modified test harness to always check that ANY stateful operator should have a constraint on the number of partitions. As part of that, the existing opt-in checks on child output partitioning were removed, as they are redundant. Author: Tathagata Das Closes #20941 from tdas/SPARK-23827. --- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 3 +- .../sql/streaming/DeduplicateSuite.scala | 8 +-- .../FlatMapGroupsWithStateSuite.scala | 5 +- .../sql/streaming/StatefulOperatorTest.scala | 49 ------------------- .../spark/sql/streaming/StreamTest.scala | 19 +++++++ .../streaming/StreamingAggregationSuite.scala | 4 +- 7 files changed, 25 insertions(+), 65 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a10ed5f2df1b5..1a83c884d55bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -62,7 +62,7 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } - private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index c351f658cb955..fa7c8ee906ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index caf2bab8a5859..0088b64d6195e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index de2b51678cea6..b1416bff87ee7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -42,8 +42,7 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { + with BeforeAndAfterAll { import testImplicits._ import GroupStateImpl._ @@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( - sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala deleted file mode 100644 index 45142278993bb..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.streaming._ - -trait StatefulOperatorTest { - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - colNames: Seq[String]): Boolean = { - val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output - val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions - val groupingAttr = attr.filter(a => colNames.contains(a.name)) - checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) - } - - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - expectedPartitioning: Partitioning): Boolean = { - val operator = sq.asInstanceOf[StreamExecution].lastExecution - .executedPlan.collect { case p: T => p } - operator.head.children.forall( - _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e44aef09f1f3c..00741d660dd2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ @@ -444,6 +445,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } + val lastExecution = currentStream.lastExecution + if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { + // Verify if stateful operators have correct metadata and distribution + // This can often catch hard to debug errors when developing stateful operators + lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s => + assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores)) + s.requiredChildDistribution.foreach { d => + withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") { + assert(d.requiredNumPartitions.isDefined) + assert(d.requiredNumPartitions.get >= 1) + if (d != AllTuples) { + assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions) + } + } + } + } + } + val (latestBatchData, allData) = sink match { case s: MemorySink => (s.latestBatchData, s.allData) case s: MemorySinkV2 => (s.latestBatchData, s.allData) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97e065193fd05..1cae8cb8d47f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions with StatefulOperatorTest { + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), From 529f847105fa8d98a5dc4d20955e4870df6bc1c5 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 31 Mar 2018 10:34:01 +0800 Subject: [PATCH 235/593] [SPARK-23040][CORE][FOLLOW-UP] Avoid double wrap result Iterator. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20449#discussion_r172414393, If `resultIter` is already a `InterruptibleIterator`, don't double wrap it. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20920 from jiangxb1987/SPARK-23040. --- .../spark/shuffle/BlockStoreShuffleReader.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 85e7e56a04a7d..4103dfb10175e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -111,8 +111,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( case None => aggregatedIter } - // Use another interruptible iterator here to support task cancellation as aggregator or(and) - // sorter may have consumed previous interruptible iterator. - new InterruptibleIterator[Product2[K, C]](context, resultIter) + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } } } From 44a9f8e6e82c300dc61ca18515aee16f17f27501 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 2 Apr 2018 09:53:37 -0700 Subject: [PATCH 236/593] [SPARK-15009][PYTHON][FOLLOWUP] Add default param checks for CountVectorizerModel ## What changes were proposed in this pull request? Adding test for default params for `CountVectorizerModel` constructed from vocabulary. This required that the param `maxDF` be added, which was done in SPARK-23615. ## How was this patch tested? Added an explicit test for CountVectorizerModel in DefaultValuesTests. Author: Bryan Cutler Closes #20942 from BryanCutler/pyspark-CountVectorizerModel-default-param-test-SPARK-15009. --- python/pyspark/ml/tests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6b4376cbf14e8..c2c4861e2aff4 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2096,6 +2096,11 @@ def test_java_params(self): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector): From 6151f29f9f589301159482044fc32717f430db6e Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Mon, 2 Apr 2018 12:00:37 -0700 Subject: [PATCH 237/593] [SPARK-23825][K8S] Requesting memory + memory overhead for pod memory ## What changes were proposed in this pull request? Kubernetes driver and executor pods should request `memory + memoryOverhead` as their resources instead of just `memory`, see https://issues.apache.org/jira/browse/SPARK-23825 ## How was this patch tested? Existing unit tests were adapted. Author: David Vogelbacher Closes #20943 from dvogelbacher/spark-23825. --- .../k8s/submit/steps/BasicDriverConfigurationStep.scala | 5 +---- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 +---- .../submit/steps/BasicDriverConfigurationStepSuite.scala | 2 +- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 6 ++++-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 347c4d2d66826..b811db324108c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -93,9 +93,6 @@ private[spark] class BasicDriverConfigurationStep( .withAmount(driverCpuCores) .build() val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryMiB}Mi") - .build() - val driverMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => @@ -117,7 +114,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewResources() .addToRequests("cpu", driverCpuQuantity) .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryLimitQuantity) + .addToLimits("memory", driverMemoryQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 98cbd5607da00..ac42385459dda 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -108,9 +108,6 @@ private[spark] class ExecutorPodFactory( SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ executorLabels val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryMiB}Mi") - .build() - val executorMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) @@ -167,7 +164,7 @@ private[spark] class ExecutorPodFactory( .withImagePullPolicy(imagePullPolicy) .withNewResources() .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryLimitQuantity) + .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() .addAllToEnv(executorEnv.asJava) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index ce068531c7673..e59c6d28a8cc2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -91,7 +91,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val resourceRequirements = preparedDriverSpec.driverContainer.getResources val requests = resourceRequirements.getRequests.asScala assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "256Mi") + assert(requests("memory").getAmount === "456Mi") val limits = resourceRequirements.getLimits.asScala assert(limits("memory").getAmount === "456Mi") assert(limits("cpu").getAmount === "4") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7755b93835047..cee8fe27039c9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -66,12 +66,14 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef assert(executor.getMetadata.getLabels.size() === 3) assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - // There is exactly 1 container with no volume mounts and default memory limits. - // Default memory limit is 1024M + 384M (minimum overhead constant). + // There is exactly 1 container with no volume mounts and default memory limits and requests. + // Default memory limit/request is 1024M + 384M (minimum overhead constant). assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getImage === executorImage) assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") From fe2b7a4568d65a62da6e6eb00fff05f248b4332c Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Mon, 2 Apr 2018 12:20:55 -0700 Subject: [PATCH 238/593] [SPARK-23285][K8S] Add a config property for specifying physical executor cores ## What changes were proposed in this pull request? As mentioned in SPARK-23285, this PR introduces a new configuration property `spark.kubernetes.executor.cores` for specifying the physical CPU cores requested for each executor pod. This is to avoid changing the semantics of `spark.executor.cores` and `spark.task.cpus` and their role in task scheduling, task parallelism, dynamic resource allocation, etc. The new configuration property only determines the physical CPU cores available to an executor. An executor can still run multiple tasks simultaneously by using appropriate values for `spark.executor.cores` and `spark.task.cpus`. ## How was this patch tested? Unit tests. felixcheung srowen jiangxb1987 jerryshao mccheah foxish Author: Yinan Li Author: Yinan Li Closes #20553 from liyinan926/master. --- docs/running-on-kubernetes.md | 15 ++++++++--- .../org/apache/spark/deploy/k8s/Config.scala | 6 +++++ .../cluster/k8s/ExecutorPodFactory.scala | 12 ++++++--- .../cluster/k8s/ExecutorPodFactorySuite.scala | 27 +++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 975b28de47e20..9c4644947c911 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -549,14 +549,23 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + spark.kubernetes.executor.request.cores + (none) + + Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task + parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. + spark.kubernetes.executor.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. @@ -593,4 +602,4 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. - \ No newline at end of file + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index da34a7e06238a..405ea476351bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -91,6 +91,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = + ConfigBuilder("spark.kubernetes.executor.request.cores") + .doc("Specify the cpu request for each executor pod") + .stringConf + .createOptional + val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") .doc("Name of the driver pod.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ac42385459dda..7143f7a6f0b71 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -83,7 +83,12 @@ private[spark] class ExecutorPodFactory( MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorCores = sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) /** @@ -111,7 +116,7 @@ private[spark] class ExecutorPodFactory( .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCores.toString) + .withAmount(executorCoresRequest) .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() @@ -130,8 +135,7 @@ private[spark] class ExecutorPodFactory( }.getOrElse(Seq.empty[EnvVar]) val executorEnv = (Seq( (ENV_DRIVER_URL, driverUrl), - // Executor backend expects integral value for executor cores, so round it up to an int. - (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_CORES, executorCores.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), // This is to set the SPARK_CONF_DIR to be /opt/spark/conf diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index cee8fe27039c9..a71a2a1b888bc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -85,6 +85,33 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor core request specification") { + var factory = new ExecutorPodFactory(baseConf, None) + var executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "1") + + val conf = baseConf.clone() + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") + factory = new ExecutorPodFactory(conf, None) + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "0.1") + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + factory = new ExecutorPodFactory(conf, None) + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "100m") + } + test("executor pod hostnames get truncated to 63 characters") { val conf = baseConf.clone() conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, From a7c19d9c21d59fd0109a7078c80b33d3da03fafd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 2 Apr 2018 21:48:44 +0200 Subject: [PATCH 239/593] [SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes ## What changes were proposed in this pull request? This PR implemented the following cleanups related to `UnsafeWriter` class: - Remove code duplication between `UnsafeRowWriter` and `UnsafeArrayWriter` - Make `BufferHolder` class internal by delegating its accessor methods to `UnsafeWriter` - Replace `UnsafeRow.setTotalSize(...)` with `UnsafeRowWriter.setTotalSize()` ## How was this patch tested? Tested by existing UTs Author: Kazuaki Ishizaki Closes #20850 from kiszk/SPARK-23713. --- .../sql/kafka010/KafkaContinuousReader.scala | 3 - .../KafkaRecordToUnsafeRowConverter.scala | 11 +- .../expressions/codegen/BufferHolder.java | 32 +-- .../codegen/UnsafeArrayWriter.java | 133 +++--------- .../expressions/codegen/UnsafeRowWriter.java | 189 +++++++----------- .../expressions/codegen/UnsafeWriter.java | 157 ++++++++++++++- .../InterpretedUnsafeProjection.scala | 90 ++++----- .../codegen/GenerateUnsafeProjection.scala | 124 +++++------- .../RowBasedKeyValueBatchSuite.java | 28 +-- .../aggregate/RowBasedHashMapGenerator.scala | 12 +- .../columnar/GenerateColumnAccessor.scala | 9 +- .../datasources/text/TextFileFormat.scala | 11 +- 12 files changed, 391 insertions(+), 408 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index e7e27876088f3..f26c134c2f6e9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -27,13 +27,10 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String /** * A [[ContinuousReader]] for data from kafka. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 1acdd56125741..f35a143e00374 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val rowWriter = new UnsafeRowWriter(7) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - bufferHolder.reset() + rowWriter.reset() if (record.key == null) { rowWriter.setNullAt(0) @@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12f..537ef244b7e81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,25 +30,21 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +82,23 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] getBuffer() { + return buffer; + } + + int getCursor() { + return cursor; + } + + void increaseCursor(int val) { + cursor += val; + } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 82cd1b24607e1..a78dd970d23e4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -32,14 +30,12 @@ */ public final class UnsafeArrayWriter extends UnsafeWriter { - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; - // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { + super(writer.getBufferHolder()); + this.elementSize = elementSize; + } + + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. int fixedPartInBytes = @@ -61,112 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } - holder.cursor += (headerInBytes + fixedPartInBytes); + increaseCursor(headerInBytes + fixedPartInBytes); } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } - } - - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + writeByte(getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + writeShort(getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + writeInt(getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + writeLong(getElementOffset(ordinal), 0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + increaseCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 2620bbcfb87a2..71c49d8ed0177 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -31,7 +28,7 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very @@ -40,29 +37,58 @@ */ public final class UnsafeRowWriter extends UnsafeWriter { - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. - private int startingOffset; + private final UnsafeRow row; + private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); + } + + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + /** + * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created at a constructor + */ + public UnsafeRow getRow() { + row.setTotalSize(totalSize()); + return row; } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. - holder.grow(fixedSize); - holder.cursor += fixedSize; + grow(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -72,25 +98,17 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); - } - } - - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } - public BufferHolder holder() { return holder; } - public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } @Override @@ -117,67 +135,49 @@ public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, int size) { - setOffsetAndSize(ordinal, holder.cursor, size); - } - - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - final long relativeOffset = currentCursor - startingOffset; - final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | (long) size; - - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + writeLong(offset, 0L); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + writeLong(offset, 0L); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + writeLong(offset, 0L); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + writeLong(offset, 0L); + writeInt(offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + writeLong(offset, 0); + writeFloat(offset, value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + if (input != null && input.changePrecision(precision, scale)) { + write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -185,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // zero-out the bytes + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + final int numBytes = bytes.length; + assert numBytes <= 16; + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } // move the cursor forward. - holder.cursor += 16; + increaseCursor(16); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - write(ordinal, input, 0, input.length); - } - - public void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index c94b5c7a367ef..de0eb6dbb76be 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -24,10 +26,73 @@ * Base class for writing Unsafe* structures. */ public abstract class UnsafeWriter { + // Keep internal buffer holder + protected final BufferHolder holder; + + // The offset of the global buffer where we start to write this structure. + protected int startingOffset; + + protected UnsafeWriter(BufferHolder holder) { + this.holder = holder; + } + + /** + * Accessor methods are delegated from BufferHolder class + */ + public final BufferHolder getBufferHolder() { + return holder; + } + + public final byte[] getBuffer() { + return holder.getBuffer(); + } + + public final void reset() { + holder.reset(); + } + + public final int totalSize() { + return holder.totalSize(); + } + + public final void grow(int neededSize) { + holder.grow(neededSize); + } + + public final int cursor() { + return holder.getCursor(); + } + + public final void increaseCursor(int val) { + holder.increaseCursor(val); + } + + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); + } + + protected void setOffsetAndSize(int ordinal, int size) { + setOffsetAndSize(ordinal, cursor(), size); + } + + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; + + write(ordinal, offsetAndSize); + } + + protected final void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); + } + } + public abstract void setNull1Bytes(int ordinal); public abstract void setNull2Bytes(int ordinal); public abstract void setNull4Bytes(int ordinal); public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); public abstract void write(int ordinal, byte value); public abstract void write(int ordinal, short value); @@ -36,8 +101,92 @@ public abstract class UnsafeWriter { public abstract void write(int ordinal, float value); public abstract void write(int ordinal, double value); public abstract void write(int ordinal, Decimal input, int precision, int scale); - public abstract void write(int ordinal, UTF8String input); - public abstract void write(int ordinal, byte[] input); - public abstract void write(int ordinal, CalendarInterval input); - public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); + + public final void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(getBuffer(), cursor()); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public final void write(int ordinal, byte[] input, int offset, int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. + increaseCursor(16); + } + + protected final void writeBoolean(long offset, boolean value) { + Platform.putBoolean(getBuffer(), offset, value); + } + + protected final void writeByte(long offset, byte value) { + Platform.putByte(getBuffer(), offset, value); + } + + protected final void writeShort(long offset, short value) { + Platform.putShort(getBuffer(), offset, value); + } + + protected final void writeInt(long offset, int value) { + Platform.putInt(getBuffer(), offset, value); + } + + protected final void writeLong(long offset, long value) { + Platform.putLong(getBuffer(), offset, value); + } + + protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(getBuffer(), offset, value); + } + + protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(getBuffer(), offset, value); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 0da5ece7e47fe..b31466f5c92d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** The row representing the expression results. */ private[this] val intermediate = new GenericInternalRow(values) - /** The row returned by the projection. */ - private[this] val result = new UnsafeRow(numFields) - - /** The buffer which holds the resulting row's backing data. */ - private[this] val holder = new BufferHolder(result, numFields * 32) + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) /** The writer that writes the intermediate result to the result row. */ private[this] val writer: InternalRow => Unit = { - val rowWriter = new UnsafeRowWriter(holder, numFields) val baseWriter = generateStructWriter( - holder, rowWriter, expressions.map(e => StructField("", e.dataType, e.nullable))) if (!expressions.exists(_.nullable)) { @@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } // Write the intermediate row to an unsafe row. - holder.reset() + rowWriter.reset() writer(intermediate) - result.setTotalSize(holder.totalSize()) - result + rowWriter.getRow() } } @@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * given buffer using the given [[UnsafeRowWriter]]. */ private def generateStructWriter( - bufferHolder: BufferHolder, rowWriter: UnsafeRowWriter, fields: Array[StructField]): InternalRow => Unit = { val numFields = fields.length // Create field writers. val fieldWriters = fields.map { field => - generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + generateFieldWriter(rowWriter, field.dataType, field.nullable) } // Create basic writer. row => { @@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * or array) to the given buffer using the given [[UnsafeWriter]]. */ private def generateFieldWriter( - bufferHolder: BufferHolder, writer: UnsafeWriter, dt: DataType, nullable: Boolean): (SpecializedGetters, Int) => Unit = { @@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case StructType(fields) => val numFields = fields.length - val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) - val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( - bufferHolder, + rowWriter, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) case row => // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. - rowWriter.reset() + rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter - val elementSize = getElementSize(elementType) + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) val elementWriter = generateFieldWriter( - bufferHolder, arrayWriter, elementType, containsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor - writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter - val keySize = getElementSize(keyType) + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) val keyWriter = generateFieldWriter( - bufferHolder, keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter - val valueSize = getElementSize(valueType) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) val valueWriter = generateFieldWriter( - bufferHolder, valueArrayWriter, valueType, valueContainsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( - bufferHolder, + valueArrayWriter, map.getBaseObject, map.getBaseOffset, map.getSizeInBytes) case map => // preserve 8 bytes to write the key array numBytes later. - bufferHolder.grow(8) - bufferHolder.cursor += 8 + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) - Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) // Write the values. - writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => - generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + generateFieldWriter(writer, udt.sqlType, nullable) case NullType => (_, _) => {} @@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * copy. */ private def writeArray( - bufferHolder: BufferHolder, arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, - array: ArrayData, - elementSize: Int): Unit = array match { + array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( - bufferHolder, + arrayWriter, unsafe.getBaseObject, unsafe.getBaseOffset, unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(bufferHolder, numElements, elementSize) + arrayWriter.initialize(numElements) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. */ private def writeUnsafeData( - bufferHolder: BufferHolder, + writer: UnsafeWriter, baseObject: AnyRef, baseOffset: Long, sizeInBytes: Int) : Unit = { - bufferHolder.grow(sizeInBytes) + writer.grow(sizeInBytes) Platform.copyMemory( baseObject, baseOffset, - bufferHolder.buffer, - bufferHolder.cursor, + writer.getBuffer, + writer.cursor, sizeInBytes) - bufferHolder.cursor += sizeInBytes + writer.increaseCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6682ba55b18b1..ab2254cd9f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + s""" final InternalRow $tmpInput = $input; if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} } """ } @@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case udt: UserDefinedType[_] => udt.sqlType case other => other } - val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } + val previousCursor = ctx.freshName("previousCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => 8 // we need 8 bytes to store offset and length } - val tmpCursor = ctx.freshName("tmpCursor") + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val previousCursor = ctx.freshName("previousCursor") + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeArrayToBuffer(ctx, element, et, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") @@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final MapData $tmpInput = $input; if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} } else { // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + $rowWriter.grow(8); + $rowWriter.increaseCursor(8); // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + final int $tmpCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } """ } @@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. - $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; + $rowWriter.grow($sizeInBytes); + $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); + $rowWriter.increaseCursor($sizeInBytes); """ } @@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") - - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$holder.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$result.setTotalSize($holder.totalSize());" - } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" - $resetBufferHolder + $rowWriter.reset(); $evalSubexpr $writeExpressions - $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", s"$rowWriter.getRow()") } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed1996..2da87113c6229 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +54,27 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 8617be88f3570..d5508275c48c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -165,18 +165,14 @@ class RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 3b5655ba0582e..2d699e8a9d088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 9647f09867643..e93908da43535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() + unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + unsafeRowWriter.getRow() } } } From 28ea4e3142b88eb396aa8dd5daf7b02b556204ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 2 Apr 2018 14:35:07 -0700 Subject: [PATCH 240/593] [SPARK-23834][TEST] Wait for connection before disconnect in LauncherServer test. It was possible that the disconnect() was called on the handle before the server had received the handshake messages, so no connection was yet attached to the handle. The fix waits until we're sure the handle has been mapped to a client connection. Author: Marcelo Vanzin Closes #20950 from vanzin/SPARK-23834. --- .../org/apache/spark/launcher/LauncherServerSuite.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5413d3a416545..f8dc0ec7a0bf6 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -196,6 +196,14 @@ public void testAppHandleDisconnect() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); client = new TestClient(s); client.send(new Hello(secret, "1.4.0")); + client.send(new SetAppId("someId")); + + // Wait until we know the server has received the messages and matched the handle to the + // connection before disconnecting. + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + assertEquals("someId", handle.getAppId()); + }); + handle.disconnect(); waitForError(client, secret); } finally { From a1351828d376a01e5ee0959cf608f767d756dd86 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 2 Apr 2018 16:41:26 -0700 Subject: [PATCH 241/593] [SPARK-23690][ML] Add handleinvalid to VectorAssembler ## What changes were proposed in this pull request? Introduce `handleInvalid` parameter in `VectorAssembler` that can take in `"keep", "skip", "error"` options. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds relevant number of NaN. "keep" figures out an example to find out what this number of NaN s should be added and throws an error when no such number could be found. ## How was this patch tested? Unit tests are added to check the behavior of `assemble` on specific rows and the transformer is called on `DataFrame`s of different configurations to test different corner cases. Author: Yogesh Garg Author: Bago Amirbekian Author: Yogesh Garg <1059168+yogeshg@users.noreply.github.com> Closes #20829 from yogeshg/rformula_handleinvalid. --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 198 ++++++++++++++---- .../ml/feature/VectorAssemblerSuite.scala | 131 +++++++++++- 3 files changed, 284 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab78..67cdb097217a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -234,7 +234,7 @@ class StringIndexerModel ( val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val (filteredDataset, keepInvalid) = getHandleInvalid match { + val (filteredDataset, keepInvalid) = $(handleInvalid) match { case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b373ae921ed38..6bf4aa38b1fcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -17,14 +17,17 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuilder +import java.util.NoSuchElementException + +import scala.collection.mutable +import scala.language.existentials import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -33,10 +36,14 @@ import org.apache.spark.sql.types._ /** * A feature transformer that merges multiple columns into a vector column. + * + * This requires one pass over the entire dataset. In case we need to infer column lengths from the + * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter. */ @Since("1.4.0") class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) @@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.4.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + @Since("2.4.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + |output). Column lengths are taken from the size of ML Attribute Group, which can be set using + |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + |""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val vectorCols = $(inputCols).filter { c => + schema(c).dataType match { + case _: VectorUDT => true + case _ => false + } + } + val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + + val featureAttributesMap = $(inputCols).map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. - if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } case _: NumericType | BooleanType => // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) + Seq(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + if (attributeGroup.attributes.isDefined) { + attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until vectorColsLengths(c)).map { i => + NumericAttribute.defaultAttr.withName(c + "_" + i) + } } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute].toArray + val lengths = featureAttributesMap.map(a => a.length).toArray + val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false) + case VectorAssembler.KEEP_INVALID => (dataset, true) + case VectorAssembler.ERROR_INVALID => (dataset, false) + } // Data transformation. val assembleFunc = udf { r: Row => - VectorAssembler.assemble(r.toSeq: _*) + VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") @@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + /** + * Infers lengths of vector columns from the first row of the dataset + * @param dataset the dataset + * @param columns name of vector columns whose lengths need to be inferred + * @return map of column names to lengths + */ + private[feature] def getVectorLengthsFromFirstRow( + dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF().select(columns.map(col): _*).first() + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + s"""Encountered null value while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + s"""Encountered empty dataframe while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + } + } + + private[feature] def getLengths( + dataset: Dataset[_], + columns: Seq[String], + handleInvalid: String): Map[String, Int] = { + val groupSizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq + val firstSizes = (missingColumns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getVectorLengthsFromFirstRow(dataset, missingColumns) + case (true, VectorAssembler.SKIP_INVALID) => + getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint + |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + .stripMargin.replaceAll("\n", " ")) + case (_, _) => Map.empty + } + groupSizes ++ firstSizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 + /** + * Returns a function that has the required information to assemble each row. + * @param lengths an array of lengths of input columns, whose size should be equal to the number + * of cells in the row (vv) + * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows + * @return a udf that can be applied on each row + */ + private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = { + val indices = mutable.ArrayBuilder.make[Int] + val values = mutable.ArrayBuilder.make[Double] + var featureIndex = 0 + + var inputColumnIndex = 0 vv.foreach { case v: Double => - if (v != 0.0) { - indices += cur + if (v.isNaN && !keepInvalid) { + throw new SparkException( + s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider + |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } else if (v != 0.0) { + indices += featureIndex values += v } - cur += 1 + inputColumnIndex += 1 + featureIndex += 1 case vec: Vector => vec.foreachActive { case (i, v) => if (v != 0.0) { - indices += cur + i + indices += featureIndex + i values += v } } - cur += vec.size + inputColumnIndex += 1 + featureIndex += vec.size case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") + if (keepInvalid) { + val length: Int = lengths(inputColumnIndex) + Array.range(0, length).foreach { i => + indices += featureIndex + i + values += Double.NaN + } + inputColumnIndex += 1 + featureIndex += length + } else { + throw new SparkException( + s"""Encountered null while assembling a row with handleInvalid = "keep". Consider + |removing nulls from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } case o => throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(featureIndex, indices.result(), values.result()).compressed } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index eca065f7e775d..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite @@ -31,30 +31,49 @@ class VectorAssemblerSuite import testImplicits._ + @transient var dfWithNullsAndNaNs: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sv = Vectors.sparse(2, Array(1), Array(3.0)) + dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)]( + (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (2, 1, 0.0, null, "a", sv, 6L, null), + (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null), + (4, 4, null, null, "a", sv, 9L, null), + (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null)) + .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls") + } + test("params") { ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble - assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) - assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + assert(assemble(Array(1), keepInvalid = true)(0.0) + === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0) + === Vectors.sparse(2, Array(1), Array(1.0))) val dv = Vectors.dense(2.0, 0.0) - assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) === + Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) - assert(assemble(0.0, dv, 1.0, sv) === + assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) === Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) - for (v <- Seq(1, "a", null)) { - intercept[SparkException](assemble(v)) - intercept[SparkException](assemble(1.0, v)) + for (v <- Seq(1, "a")) { + intercept[SparkException](assemble(Array(1), keepInvalid = true)(v)) + intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v)) } } test("assemble should compress vectors") { import org.apache.spark.ml.feature.VectorAssembler.assemble - val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0)) assert(v1.isInstanceOf[SparseVector]) - val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + val sv = Vectors.sparse(1, Array(0), Array(4.0)) + val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv) assert(v2.isInstanceOf[DenseVector]) } @@ -147,4 +166,94 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null) + === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1), keepInvalid = false)(null)) + intercept[SparkException](assemble(Array(2), keepInvalid = false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNullsAndNaNs + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + .getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val attributeZ = new AttributeGroup( + "z", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata()) + .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter) + val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata) + output.collect() + output + } + + def runWithFirstRow(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs) + output.collect() + output + } + + def runWithAllNullVectors(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode) + .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2")) + output.collect() + output + } + + // behavior when vector size hint is given + assert(runWithMetadata("keep").count() == 6, "should keep all rows") + assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls") + // should throw error with nulls + intercept[SparkException](runWithMetadata("error")) + // should throw error with NaNs + intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4")) + + // behavior when first row has information + assert(intercept[RuntimeException](runWithFirstRow("keep").count()) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls") + intercept[SparkException](runWithFirstRow("error")) + + // behavior when vector column is all null + assert(intercept[RuntimeException](runWithAllNullVectors("skip")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(intercept[NullPointerException](runWithAllNullVectors("error")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + + // behavior when scalar column is all null + assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) + } + } From 441d0d0766e9a6ac4c6ff79680394999ff7191fd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 3 Apr 2018 09:31:47 +0800 Subject: [PATCH 242/593] [SPARK-19964][CORE] Avoid reading from remote repos in SparkSubmitSuite. These tests can fail with a timeout if the remote repos are not responding, or slow. The tests don't need anything from those repos, so use an empty ivy config file to avoid setting up the defaults. The tests are passing reliably for me locally now, and failing more often than not today without this change since http://dl.bintray.com/spark-packages/maven doesn't seem to be loading from my machine. Author: Marcelo Vanzin Closes #20916 from vanzin/SPARK-19964. --- .../org/apache/spark/deploy/DependencyUtils.scala | 13 ++++++++----- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- .../apache/spark/deploy/SparkSubmitArguments.scala | 2 ++ .../apache/spark/deploy/worker/DriverWrapper.scala | 13 +++++++++---- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ab319c860ee69..fac834a70b893 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -33,7 +33,8 @@ private[deploy] object DependencyUtils { packagesExclusions: String, packages: String, repositories: String, - ivyRepoPath: String): String = { + ivyRepoPath: String, + ivySettingsPath: Option[String]): String = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { packagesExclusions.split(",") @@ -41,10 +42,12 @@ private[deploy] object DependencyUtils { Nil } // Create the IvySettings, either load from file or build defaults - val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + val ivySettings = ivySettingsPath match { + case Some(path) => + SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath)) + + case None => + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) } SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3965f17f4b56e..eddbedeb1024d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -359,7 +359,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( - args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index e7796d4ddbe34..8e7070593687b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var ivySettingsPath: Option[String] = None var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false @@ -184,6 +185,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b19c9904d5982..3f71237164a15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -79,12 +79,17 @@ object DriverWrapper extends Logging { val secMgr = new SecurityManager(sparkConf) val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) - val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = - Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") - .map(sys.props.get(_).orNull) + val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) = + Seq( + "spark.jars.excludes", + "spark.jars.packages", + "spark.jars.repositories", + "spark.jars.ivy", + "spark.jars.ivySettings" + ).map(sys.props.get(_).orNull) val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, - packages, repositories, ivyRepoPath) + packages, repositories, ivyRepoPath, Option(ivySettingsPath)) val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d86ef907b4492..0d7c342a5eacd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,9 @@ class SparkSubmitSuite // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val emptyIvySettings = File.createTempFile("ivy", ".xml") + FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + override def beforeEach() { super.beforeEach() } @@ -520,6 +523,7 @@ class SparkSubmitSuite "--repositories", repo, "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -530,7 +534,6 @@ class SparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), @@ -540,6 +543,7 @@ class SparkSubmitSuite "--conf", s"spark.jars.repositories=$repo", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -550,7 +554,6 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -563,6 +566,7 @@ class SparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--packages", main.toString, "--repositories", repo, + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", "--verbose", "--conf", "spark.ui.enabled=false", rScriptDir) @@ -573,7 +577,6 @@ class SparkSubmitSuite test("include an external JAR in SparkR") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) From 8020f66fc47140a1b5f843fb18c34ec80541d5ca Mon Sep 17 00:00:00 2001 From: lemonjing <932191671@qq.com> Date: Tue, 3 Apr 2018 09:36:44 +0800 Subject: [PATCH 243/593] [MINOR][DOC] Fix a few markdown typos ## What changes were proposed in this pull request? Easy fix in the markdown. ## How was this patch tested? jekyII build test manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: lemonjing <932191671@qq.com> Closes #20897 from Lemonjing/master. --- docs/ml-guide.md | 2 +- docs/mllib-feature-extraction.md | 4 ++-- docs/mllib-pmml-model-export.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 702bcf748fc74..aea07be34cb86 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases. * The class and trait hierarchy for logistic regression model summaries was changed to be cleaner and better accommodate the addition of the multi-class summary. This is a breaking change for user code that casts a `LogisticRegressionTrainingSummary` to a -` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail (_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which will still work correctly for both multinomial and binary cases. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 75aea70601875..8b89296b14cdd 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -278,8 +278,8 @@ for details on the API. multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. -Qu8T948*1# -Denoting the `scalingVec` as "`w`," this transformation may be written as: + +Denoting the `scalingVec` as "`w`", this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index d3530908706d0..f567565437927 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API * Table of contents {:toc} -## `spark.mllib` supported models +## spark.mllib supported models `spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). @@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a - + From 7cf9fab33457ccc9b2d548f15dd5700d5e8d08ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 3 Apr 2018 21:26:49 +0800 Subject: [PATCH 244/593] [MINOR][CORE] Show block manager id when remove RDD/Broadcast fails. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20924#discussion_r177987175, show block manager id when remove RDD/Broadcast fails. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20960 from jiangxb1987/bmid. --- .../apache/spark/storage/BlockManagerMasterEndpoint.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56b95c31eb4c3..8e8f7d197c9ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint( val futures = blockManagerInfo.values.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId", e) + logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", + e) 0 // zero blocks were removed } }.toSeq @@ -195,7 +196,8 @@ class BlockManagerMasterEndpoint( val futures = requiredBlockManagers.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove broadcast $broadcastId", e) + logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + + s"${bm.blockManagerId}", e) 0 // zero blocks were removed } }.toSeq From 66a3a5a2dc83e03dedcee9839415c1ddc1fb8125 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Apr 2018 11:05:29 -0700 Subject: [PATCH 245/593] [SPARK-23099][SS] Migrate foreach sink to DataSourceV2 ## What changes were proposed in this pull request? Migrate foreach sink to DataSourceV2. Since the previous attempt at this PR #20552, we've changed and strictly defined the lifecycle of writer components. This means we no longer need the complicated lifecycle shim from that PR; it just naturally works. ## How was this patch tested? existing tests Author: Jose Torres Closes #20951 from jose-torres/foreach. --- .../sql/execution/streaming/ForeachSink.scala | 68 ----------- .../sources/ForeachWriterProvider.scala | 111 ++++++++++++++++++ .../sql/streaming/DataStreamWriter.scala | 4 +- .../ForeachWriterSuite.scala} | 83 ++++++------- .../sql/streaming/StreamingQuerySuite.scala | 1 + 5 files changed, 156 insertions(+), 111 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{ForeachSinkSuite.scala => sources/ForeachWriterSuite.scala} (77%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala deleted file mode 100644 index 2cc54107f8b83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor - -/** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @tparam T The expected type of the sink. - */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000000000..df5d69d57e36f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc903168cfa0..effc1471e8e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala similarity index 77% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index b249dd41a84a6..03bf71b3f4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.concurrent.ConcurrentLinkedQueue @@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -47,9 +48,9 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .start() def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { - import ForeachSinkSuite._ + import ForeachWriterSuite._ - val events = ForeachSinkSuite.allEvents() + val events = ForeachWriterSuite.allEvents() assert(events.size === 2) // one seq of events for each of the 2 partitions // Verify both seq of events have an Open event as the first event @@ -64,13 +65,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } // -- batch 0 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(1, 2, 3, 4) query.processAllAvailable() verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() verifyOutput(expectedVersion = 1, expectedData = 5 to 8) @@ -95,27 +96,27 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf input.addData(1, 2, 3, 4) query.processAllAvailable() - var allEvents = ForeachSinkSuite.allEvents() + var allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() // -- batch 1 --------------------------------------- input.addData(5, 6, 7, 8) query.processAllAvailable() - allEvents = ForeachSinkSuite.allEvents() + allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) @@ -131,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("error") + throw new RuntimeException("ForeachSinkSuite error") } }).start() input.addData(1, 2, 3, 4) @@ -141,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") assert(query.isActive === false) - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") } } @@ -177,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf inputData.addData(10, 11, 12) query.processAllAvailable() - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) } finally { @@ -216,21 +217,21 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 3) val expectedEvents = Seq( Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) ) assert(allEvents === expectedEvents) @@ -258,7 +259,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } /** A global object to collect events in the executor */ -object ForeachSinkSuite { +object ForeachWriterSuite { trait Event @@ -285,21 +286,21 @@ object ForeachSinkSuite { /** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) + events += ForeachWriterSuite.Open(partition = partitionId, version = version) true } override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) + events += ForeachWriterSuite.Process(value) } override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 08749b49997e0..20942ed93897c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.DataReaderFactory From 1035aaa61704b2790192d3186fe37e678553d36d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Apr 2018 01:36:58 +0200 Subject: [PATCH 246/593] [SPARK-23587][SQL] Add interpreted execution for MapObjects expression ## What changes were proposed in this pull request? Add interpreted execution for `MapObjects` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20771 from viirya/SPARK-23587. --- .../expressions/objects/objects.scala | 110 ++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 67 ++++++++++- 2 files changed, 165 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index adf9ddf327c96..0e9d357c19c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.JavaConverters._ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -501,12 +502,22 @@ case class LambdaVariable( value: String, isNull: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression - with Unevaluable with NonSQLExpression { + nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + input.get(0, dataType) + } override def genCode(ctx: CodegenContext): ExprCode = { ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") } + + // This won't be called as `genCode` is overrided, just overriding it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev } /** @@ -599,8 +610,92 @@ case class MapObjects private( override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + // The data with UserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[Seq[_]] + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].array + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " + + "resulting collection.") + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse( @@ -647,13 +742,6 @@ case class MapObjects private( case _ => "" } - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - // `MapObjects` generates a while loop to traverse the elements of the input collection. We // need to take care of Seq and List because they may have O(n) complexity for indexed accessing // like `list.get(1)`. Here we use Iterator to traverse Seq and List. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1f6964dfef598..0edd27c8241e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkFunSuite} @@ -25,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23587: MapObjects should support interpreted execution") { + def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val expected = Seq(2, 3, 4) + + val inputObject = BoundReference(0, inputType, nullable = true) + val optClass = Option(collectionCls) + val mapObj = MapObjects(function, inputObject, elementType, true, optClass) + val row = InternalRow.fromSeq(Seq(collection)) + val result = mapObj.eval(row) + + collectionCls match { + case null => + assert(result.asInstanceOf[ArrayData].array.toSeq == expected) + case l if classOf[java.util.List[_]].isAssignableFrom(l) => + assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected) + case s if classOf[Seq[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[Seq[_]].toSeq == expected) + case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet) + } + } + + val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]], + classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], + classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], + classOf[java.util.Stack[Int]], null) + + val list = new java.util.ArrayList[Int]() + list.add(1) + list.add(2) + list.add(3) + val arrayData = new GenericArrayData(Array(1, 2, 3)) + val vector = new java.util.Vector[Int]() + vector.add(1) + vector.add(2) + vector.add(3) + val stack = new java.util.Stack[Int]() + stack.add(1) + stack.add(2) + stack.add(3) + + Seq( + (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), + (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[Object])), + (Array(1, 2, 3), ObjectType(classOf[Object])), + (list, ObjectType(classOf[java.util.List[Int]])), + (vector, ObjectType(classOf[java.util.Vector[Int]])), + (stack, ObjectType(classOf[java.util.Stack[Int]])), + (arrayData, ArrayType(IntegerType)) + ).foreach { case (collection, inputType) => + customCollectionClasses.foreach(testMapObjects(collection, _, inputType)) + + // Unsupported custom collection class + val errMsg = intercept[RuntimeException] { + testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) + }.getMessage() + assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " + + "as resulting collection.")) + } + } + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { val cls = classOf[java.lang.Integer] val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) From 359375eff74630c9f0ea5a90ab7d45bf1b281ed0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 3 Apr 2018 17:09:12 -0700 Subject: [PATCH 247/593] [SPARK-23809][SQL] Active SparkSession should be set by getOrCreate ## What changes were proposed in this pull request? Currently, the active spark session is set inconsistently (e.g., in createDataFrame, prior to query execution). Many places in spark also incorrectly query active session when they should be calling activeSession.getOrElse(defaultSession) and so might get None even if a Spark session exists. The semantics here can be cleaned up if we also set the active session when the default session is set. Related: https://github.com/apache/spark/pull/20926/files ## How was this patch tested? Unit test, existing test. Note that if https://github.com/apache/spark/pull/20926 merges first we should also update the tests there. Author: Eric Liang Closes #20927 from ericl/active-session-cleanup. --- .../org/apache/spark/sql/SparkSession.scala | 14 +++++++++++++- .../spark/sql/SparkSessionBuilderSuite.scala | 18 ++++++++++++++++++ .../apache/spark/sql/test/TestSQLContext.scala | 1 + .../apache/spark/sql/hive/test/TestHive.scala | 3 +++ 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 734573ba31f71..b107492fbb330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -951,7 +951,8 @@ object SparkSession { session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } - defaultSession.set(session) + setDefaultSession(session) + setActiveSession(session) // Register a successfully instantiated context to the singleton. This should be at the // end of the class definition so that the singleton is updated only if there is no @@ -1027,6 +1028,17 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: SparkSession = { + getActiveSession.getOrElse(getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index c0301f2ce2d66..44bf8624a6bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(SparkSession.builder().getOrCreate() == session) } + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[IllegalStateException](SparkSession.active) + session.stop() + } + test("config options are propagated to existing SparkSession") { val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 3038b822beb4a..17603deacdcdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,6 +35,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) } SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) @transient override lazy val sessionState: SessionState = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 814038d4ef7af..a7006a16d7b73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -179,6 +179,9 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", From 5cfd5fabcdbd77a806b98a6dd59b02772d2f6dee Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 3 Apr 2018 17:25:54 -0700 Subject: [PATCH 248/593] [SPARK-23802][SQL] PropagateEmptyRelation can leave query plan in unresolved state ## What changes were proposed in this pull request? Add cast to nulls introduced by PropagateEmptyRelation so in cases they're part of coalesce they will not break its type checking rules ## How was this patch tested? Added unit test Author: Robert Kruszewski Closes #20914 from robert3005/rk/propagate-empty-fix. --- .../optimizer/PropagateEmptyRelation.scala | 8 ++++-- .../PropagateEmptyRelationSuite.scala | 26 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index a6e5aa6daca65..c3fdb924243df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Collapse plans consisting empty local relations generated by [[PruneFilters]]. @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ -object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { +object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport { private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match { case p: LocalRelation => p.data.isEmpty case _ => false @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // Construct a project list from plan's output, while the value is always NULL. private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = - plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + + override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 3964508e3a55e..f1ce7543ffdc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceIntersectWithSemiJoin, PushDownPredicate, PruneFilters, - PropagateEmptyRelation) :: Nil + PropagateEmptyRelation, + CollapseProject) :: Nil } object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters, + CollapseProject) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) @@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, FullOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), (true, false, LeftSemi, Some(LocalRelation('a.int))), @@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, - Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("propagate empty relation keeps the plan resolved") { + val query = testRelation1.join( + LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + val optimized = Optimize.execute(query.analyze) + assert(optimized.resolved) + } } From 16ef6baa36ac11c72cfeafaa2363e6b69f0ba573 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 4 Apr 2018 14:31:03 +0800 Subject: [PATCH 249/593] [SPARK-23826][TEST] TestHiveSparkSession should set default session ## What changes were proposed in this pull request? In TestHive, the base spark session does this in getOrCreate(), we emulate that behavior for tests. ## How was this patch tested? N/A Author: gatorsmile Closes #20969 from gatorsmile/setDefault. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a7006a16d7b73..965aea2b61456 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,10 +159,6 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => - // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as - // TestSparkSession, but doing this the same way breaks many tests in the package. We need - // to investigate and find a different strategy. - def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From 5197562afe8534b29f5a0d72683c2859f796275d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Apr 2018 14:39:19 +0800 Subject: [PATCH 250/593] [SPARK-21351][SQL] Update nullability based on children's output ## What changes were proposed in this pull request? This pr added a new optimizer rule `UpdateNullabilityInAttributeReferences ` to update the nullability that `Filter` changes when having `IsNotNull`. In the master, optimized plans do not respect the nullability when `Filter` has `IsNotNull`. This wrongly generates unnecessary code. For example: ``` scala> val df = Seq((Some(1), Some(2))).toDF("a", "b") scala> val bIsNotNull = df.where($"b" =!= 2).select($"b") scala> val targetQuery = bIsNotNull.distinct scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = true scala> targetQuery.debugCodegen Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- Exchange hashpartitioning(b#19, 200) +- *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- *Project [_2#16 AS b#19] +- *Filter isnotnull(_2#16) +- LocalTableScan [_1#15, _2#16] Generated code: ... /* 124 */ protected void processNext() throws java.io.IOException { ... /* 132 */ // output the result /* 133 */ /* 134 */ while (agg_mapIter.next()) { /* 135 */ wholestagecodegen_numOutputRows.add(1); /* 136 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 137 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 138 */ /* 139 */ boolean agg_isNull4 = agg_aggKey.isNullAt(0); /* 140 */ int agg_value4 = agg_isNull4 ? -1 : (agg_aggKey.getInt(0)); /* 141 */ agg_rowWriter1.zeroOutNullBytes(); /* 142 */ // We don't need this NULL check because NULL is filtered out in `$"b" =!=2` /* 143 */ if (agg_isNull4) { /* 144 */ agg_rowWriter1.setNullAt(0); /* 145 */ } else { /* 146 */ agg_rowWriter1.write(0, agg_value4); /* 147 */ } /* 148 */ append(agg_result1); /* 149 */ /* 150 */ if (shouldStop()) return; /* 151 */ } /* 152 */ /* 153 */ agg_mapIter.close(); /* 154 */ if (agg_sorter == null) { /* 155 */ agg_hashMap.free(); /* 156 */ } /* 157 */ } /* 158 */ /* 159 */ } ``` In the line 143, we don't need this NULL check because NULL is filtered out in `$"b" =!=2`. This pr could remove this NULL check; ``` scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = false scala> targetQuery.debugCodegen ... Generated code: ... /* 144 */ protected void processNext() throws java.io.IOException { ... /* 152 */ // output the result /* 153 */ /* 154 */ while (agg_mapIter.next()) { /* 155 */ wholestagecodegen_numOutputRows.add(1); /* 156 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 157 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 158 */ /* 159 */ int agg_value4 = agg_aggKey.getInt(0); /* 160 */ agg_rowWriter1.write(0, agg_value4); /* 161 */ append(agg_result1); /* 162 */ /* 163 */ if (shouldStop()) return; /* 164 */ } /* 165 */ /* 166 */ agg_mapIter.close(); /* 167 */ if (agg_sorter == null) { /* 168 */ agg_hashMap.free(); /* 169 */ } /* 170 */ } ``` ## How was this patch tested? Added `UpdateNullabilityInAttributeReferencesSuite` for unit tests. Author: Takeshi Yamamuro Closes #18576 from maropu/SPARK-21351. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 ++++++- ...ullabilityInAttributeReferencesSuite.scala | 57 +++++++++++++++++++ .../optimizer/complexTypesSuite.scala | 9 --- .../org/apache/spark/sql/DataFrameSuite.scala | 5 -- 4 files changed, 75 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2829d1d81eb1a..9a1bbc675e397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. + */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 0000000000000..09b11f5aba2a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. + val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 21ed987627b3b..633d86d495581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .groupBy($"foo")("1") checkRule(structRel, structExpected) - // These tests must use nullable attributes from the base relation for the following reason: - // in the 'original' plans below, the Aggregate node produced by groupBy() has a - // nullable AttributeReference to a1, because both array indexing and map lookup are - // nullable expressions. After optimization, the same attribute is now non-nullable, - // but the AttributeReference is not updated to reflect this. In the 'expected' plans, - // the grouping expressions have the same nullability as the original attribute in the - // relation. If that attribute is non-nullable, the tests will fail as the plans will - // compare differently, so for these tests we must use a nullable attribute. See - // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") .groupBy($"a1")("1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f7b3393f65cb1..60e84e6ee7504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2055,11 +2055,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will // result in null output), the involved columns are converted to not nullable; From a35523653cdac039ee2ddff316bc2c25d6514a91 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Apr 2018 18:36:15 +0200 Subject: [PATCH 251/593] [SPARK-23583][SQL] Invoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `Invoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20797 from kiszk/SPARK-28583. --- .../spark/sql/catalyst/ScalaReflection.scala | 48 +++++++++++++- .../expressions/objects/objects.scala | 56 ++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 65 +++++++++++++++++++ 3 files changed, 163 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9a4bf0075a178..1aae3aea3a31a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection { "interface", "long", "native", "new", "null", "package", "private", "protected", "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while") + + val typeJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[DateType.InternalType], + TimestampType -> classOf[TimestampType.InternalType], + BinaryType -> classOf[BinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long] + ) + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e9d357c19c63..a455c1c821a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.objects -import java.lang.reflect.Modifier +import java.lang.reflect.{Method, Modifier} import scala.collection.JavaConverters._ import scala.collection.mutable.Builder @@ -28,7 +28,7 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression { (argCode, argValues.mkString(", "), resultIsNull) } + + /** + * Evaluate each argument with a given row, invoke a method with a given object and arguments, + * and cast a return value if the return type can be mapped to a Java Boxed type + * + * @param obj the object for the method to be called. If null, perform s static method call + * @param method the method object to be called + * @param arguments the arguments used for the method call + * @param input the row used for evaluating arguments + * @param dataType the data type of the return object + * @return the return object of a method call + */ + def invoke( + obj: Any, + method: Method, + arguments: Seq[Expression], + input: InternalRow, + dataType: DataType): Any = { + val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) + if (needNullCheck && args.exists(_ == null)) { + // return null if one of arguments is null + null + } else { + val ret = method.invoke(obj, args: _*) + val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType) + if (boxedClass.isDefined) { + boxedClass.get.cast(ret) + } else { + ret + } + } + } } /** @@ -264,12 +296,11 @@ case class Invoke( propagateNull: Boolean = true, returnNullable : Boolean = true) extends InvokeLike { + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - private lazy val encodedFunctionName = TermName(functionName).encodedName.toString @transient lazy val method = targetObject.dataType match { @@ -283,6 +314,21 @@ case class Invoke( case _ => None } + override def eval(input: InternalRow): Any = { + val obj = targetObject.eval(input) + if (obj == null) { + // return null if obj is null + null + } else { + val invokeMethod = if (method.isDefined) { + method.get + } else { + obj.getClass.getDeclaredMethod(functionName, argClasses: _*) + } + invoke(obj, invokeMethod, arguments, input, dataType) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0edd27c8241e8..9bfe2916b0820 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +class InvokeTargetClass extends Serializable { + def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 + def filterPrimitiveInt(e: Int): Boolean = e > 0 + def binOp(e1: Int, e2: Double): Double = e1 + e2 +} + +class InvokeTargetSubClass extends InvokeTargetClass { + override def binOp(e1: Int, e2: Double): Double = e1 - e2 +} class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23583: Invoke should support interpreted execution") { + val targetObject = new InvokeTargetClass + val funcClass = classOf[InvokeTargetClass] + val funcObj = Literal.create(targetObject, ObjectType(funcClass)) + val targetSubObject = new InvokeTargetSubClass + val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass])) + val funcNullObj = Literal.create(null, ObjectType(funcClass)) + + val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true)) + val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false)) + val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false)) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt), + false, InternalRow.fromSeq(Seq(-1))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(null))) + + checkObjectExprEvaluation( + Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25)) + + checkObjectExprEvaluation( + Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } + // by scala values instead of catalyst values. + private def checkObjectExprEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + checkEvaluationWithoutCodegen(expr, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvaluationWithUnsafeProjection( + expr, + expected, + inputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + } + checkEvaluationWithOptimization(expr, expected, inputRow) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From cccaaa14ad775fb981e501452ba2cc06ff5c0f0a Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Wed, 4 Apr 2018 12:30:52 -0700 Subject: [PATCH 252/593] [SPARK-23668][K8S] Add config option for passing through k8s Pod.spec.imagePullSecrets ## What changes were proposed in this pull request? Pass through the `imagePullSecrets` option to the k8s pod in order to allow user to access private image registries. See https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ ## How was this patch tested? Unit tests + manual testing. Manual testing procedure: 1. Have private image registry. 2. Spark-submit application with no `spark.kubernetes.imagePullSecret` set. Do `kubectl describe pod ...`. See the error message: ``` Error syncing pod, skipping: failed to "StartContainer" for "spark-kubernetes-driver" with ErrImagePull: "rpc error: code = 2 desc = Error: Status 400 trying to pull repository ...: \"{\\n \\\"errors\\\" : [ {\\n \\\"status\\\" : 400,\\n \\\"message\\\" : \\\"Unsupported docker v1 repository request for '...'\\\"\\n } ]\\n}\"" ``` 3. Create secret `kubectl create secret docker-registry ...` 4. Spark-submit with `spark.kubernetes.imagePullSecret` set to the new secret. See that deployment was successful. Author: Andrew Korzhuev Author: Andrew Korzhuev Closes #20811 from andrusha/spark-23668-image-pull-secrets. --- .../org/apache/spark/deploy/k8s/Config.scala | 7 ++++ .../spark/deploy/k8s/KubernetesUtils.scala | 13 +++++++ .../steps/BasicDriverConfigurationStep.scala | 7 +++- .../cluster/k8s/ExecutorPodFactory.scala | 4 +++ .../deploy/k8s/KubernetesUtilsTest.scala | 36 +++++++++++++++++++ .../BasicDriverConfigurationStepSuite.scala | 8 ++++- .../cluster/k8s/ExecutorPodFactorySuite.scala | 5 +++ 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 405ea476351bb..82f6c714f3555 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -54,6 +54,13 @@ private[spark] object Config extends Logging { .checkValues(Set("Always", "Never", "IfNotPresent")) .createWithDefault("IfNotPresent") + val IMAGE_PULL_SECRETS = + ConfigBuilder("spark.kubernetes.container.image.pullSecrets") + .doc("Comma separated list of the Kubernetes secrets used " + + "to access private image registries.") + .stringConf + .createOptional + val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5bc070147d3a8..5b2bb819cdb14 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.deploy.k8s +import io.fabric8.kubernetes.api.model.LocalObjectReference + import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -35,6 +37,17 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } + /** + * Parses comma-separated list of imagePullSecrets into K8s-understandable format + */ + def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { + imagePullSecrets match { + case Some(secretsCommaSeparated) => + secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList + case None => Nil + } + } + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b811db324108c..fcb1db8008053 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit.steps import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ @@ -51,6 +51,8 @@ private[spark] class BasicDriverConfigurationStep( .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) + // CPU settings private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) @@ -129,6 +131,8 @@ private[spark] class BasicDriverConfigurationStep( case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() } + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() .withName(driverPodName) @@ -138,6 +142,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewSpec() .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 7143f7a6f0b71..8607d6fba3234 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -68,6 +68,7 @@ private[spark] class ExecutorPodFactory( .get(EXECUTOR_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the executor container image")) private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) private val blockManagerPort = sparkConf .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) @@ -103,6 +104,8 @@ private[spark] class ExecutorPodFactory( nodeToLocalTaskCount: Map[String, Int]): Pod = { val name = s"$executorPodNamePrefix-exec-$executorId" + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod // name as the hostname. This preserves uniqueness since the end of name contains // executorId @@ -194,6 +197,7 @@ private[spark] class ExecutorPodFactory( .withHostname(hostname) .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala new file mode 100644 index 0000000000000..cf41b22e241af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.LocalObjectReference + +import org.apache.spark.SparkFunSuite + +class KubernetesUtilsTest extends SparkFunSuite { + + test("testParseImagePullSecrets") { + val noSecrets = KubernetesUtils.parseImagePullSecrets(None) + assert(noSecrets === Nil) + + val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) + assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) + + val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) + assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e59c6d28a8cc2..ee450fff8d376 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -51,6 +51,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") + .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") val submissionStep = new BasicDriverConfigurationStep( APP_ID, @@ -103,7 +104,12 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, SPARK_APP_NAME_ANNOTATION -> APP_NAME) assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never") + + val driverPodSpec = preparedDriverSpec.driverPod.getSpec + assert(driverPodSpec.getRestartPolicy === "Never") + assert(driverPodSpec.getImagePullSecrets.size() === 2) + assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a71a2a1b888bc..d73df20f0f956 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -33,6 +33,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" private val executorImage: String = "executor-image" + private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" private val driverPod = new PodBuilder() .withNewMetadata() .withName(driverPodName) @@ -54,6 +55,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set(IMAGE_PULL_SECRETS, imagePullSecrets) } test("basic executor pod has reasonable defaults") { @@ -76,6 +78,9 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") + assert(executor.getSpec.getImagePullSecrets.size() === 2) + assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") // The pod has no node selector, volumes. assert(executor.getSpec.getNodeSelector.isEmpty) From d8379e5bc3629f4e8233ad42831bdaf68c24cfeb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 4 Apr 2018 15:43:58 -0700 Subject: [PATCH 253/593] [SPARK-23838][WEBUI] Running SQL query is displayed as "completed" in SQL tab ## What changes were proposed in this pull request? A running SQL query would appear as completed in the Spark UI: ![image1](https://user-images.githubusercontent.com/1097932/38170733-3d7cb00c-35bf-11e8-994c-43f2d4fa285d.png) We can see the query in "Completed queries", while in in the job page we see it's still running Job 132. ![image2](https://user-images.githubusercontent.com/1097932/38170735-48f2c714-35bf-11e8-8a41-6fae23543c46.png) After some time in the query still appears in "Completed queries" (while it's still running), but the "Duration" gets increased. ![image3](https://user-images.githubusercontent.com/1097932/38170737-50f87ea4-35bf-11e8-8b60-000f6f918964.png) To reproduce, we can run a query with multiple jobs. E.g. Run TPCDS q6. The reason is that updates from executions are written into kvstore periodically, and the job start event may be missed. ## How was this patch tested? Manually run the job again and check the SQL Tab. The fix is pretty simple. Author: Gengliang Wang Closes #20955 from gengliangwang/jobCompleted. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 3 ++- .../spark/sql/execution/ui/SQLAppStatusListener.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index e751ce39cd5d7..582528777f90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -39,7 +39,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() sqlStore.executionsList().foreach { e => - val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isRunning = e.completionTime.isEmpty || + e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } if (isRunning) { running += e diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 71e9f93c4566e..2b6bb48467eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -88,7 +88,7 @@ class SQLAppStatusListener( exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) exec.stages ++= event.stageIds.toSet - update(exec) + update(exec, force = true) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -308,11 +308,13 @@ class SQLAppStatusListener( }) } - private def update(exec: LiveExecutionData): Unit = { + private def update(exec: LiveExecutionData, force: Boolean = false): Unit = { val now = System.nanoTime() if (exec.endEvents >= exec.jobs.size + 1) { exec.write(kvstore, now) liveExecutions.remove(exec.executionId) + } else if (force) { + exec.write(kvstore, now) } else if (liveUpdatePeriodNs >= 0) { if (now - exec.lastWriteTime > liveUpdatePeriodNs) { exec.write(kvstore, now) From d3bd0435ee4ff3d414f32cce3f58b6b9f67e68bc Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 4 Apr 2018 15:51:27 -0700 Subject: [PATCH 254/593] [SPARK-23637][YARN] Yarn might allocate more resource if a same executor is killed multiple times. ## What changes were proposed in this pull request? `YarnAllocator` uses `numExecutorsRunning` to track the number of running executor. `numExecutorsRunning` is used to check if there're executors missing and need to allocate more. In current code, `numExecutorsRunning` can be negative when driver asks to kill a same idle executor multiple times. ## How was this patch tested? UT added Author: jinxing Closes #20781 from jinxing64/SPARK-23637. --- .../spark/deploy/yarn/YarnAllocator.scala | 36 +++++++------- .../deploy/yarn/YarnAllocatorSuite.scala | 48 ++++++++++++++++++- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a537243d641cb..ebee3d431744d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -81,7 +81,8 @@ private[yarn] class YarnAllocator( private val releasedContainers = Collections.newSetFromMap[ContainerId]( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - private val numExecutorsRunning = new AtomicInteger(0) + private val runningExecutors = Collections.newSetFromMap[String]( + new ConcurrentHashMap[String, java.lang.Boolean]()) private val numExecutorsStarting = new AtomicInteger(0) @@ -166,7 +167,7 @@ private[yarn] class YarnAllocator( clock = newClock } - def getNumExecutorsRunning: Int = numExecutorsRunning.get() + def getNumExecutorsRunning: Int = runningExecutors.size() def getNumExecutorsFailed: Int = synchronized { val endTime = clock.getTimeMillis() @@ -242,12 +243,11 @@ private[yarn] class YarnAllocator( * Request that the ResourceManager release the container running the specified executor. */ def killExecutor(executorId: String): Unit = synchronized { - if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.get(executorId).get - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - logWarning(s"Attempted to kill unknown executor $executorId!") + executorIdToContainer.get(executorId) match { + case Some(container) if !releasedContainers.contains(container.getId) => + internalReleaseContainer(container) + runningExecutors.remove(executorId) + case _ => logWarning(s"Attempted to kill unknown executor $executorId!") } } @@ -274,7 +274,7 @@ private[yarn] class YarnAllocator( "Launching executor count: %d. Cluster resources: %s.") .format( allocatedContainers.size, - numExecutorsRunning.get, + runningExecutors.size, numExecutorsStarting.get, allocateResponse.getAvailableResources)) @@ -286,7 +286,7 @@ private[yarn] class YarnAllocator( logDebug("Completed %d containers".format(completedContainers.size)) processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." - .format(completedContainers.size, numExecutorsRunning.get)) + .format(completedContainers.size, runningExecutors.size)) } } @@ -300,9 +300,9 @@ private[yarn] class YarnAllocator( val pendingAllocate = getPendingAllocate val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - - numExecutorsStarting.get - numExecutorsRunning.get + numExecutorsStarting.get - runningExecutors.size logDebug(s"Updating resource requests, target: $targetNumExecutors, " + - s"pending: $numPendingAllocate, running: ${numExecutorsRunning.get}, " + + s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") if (missing > 0) { @@ -502,7 +502,7 @@ private[yarn] class YarnAllocator( s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { - numExecutorsRunning.incrementAndGet() + runningExecutors.add(executorId) numExecutorsStarting.decrementAndGet() executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -513,7 +513,7 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (numExecutorsRunning.get < targetNumExecutors) { + if (runningExecutors.size() < targetNumExecutors) { numExecutorsStarting.incrementAndGet() if (launchContainers) { launcherPool.execute(new Runnable { @@ -554,7 +554,7 @@ private[yarn] class YarnAllocator( } else { logInfo(("Skip launching executorRunnable as running executors count: %d " + "reached target executors count: %d.").format( - numExecutorsRunning.get, targetNumExecutors)) + runningExecutors.size, targetNumExecutors)) } } } @@ -569,7 +569,11 @@ private[yarn] class YarnAllocator( val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() + containerIdToExecutorId.get(containerId) match { + case Some(executorId) => runningExecutors.remove(executorId) + case None => logWarning(s"Cannot find executorId for container: ${containerId.toString}") + } + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, onHostStr, diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index cb1e3c5268510..525abb6f2b350 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -251,11 +251,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (1) } + test("kill same executor multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val executorToKill = handler.executorIdToContainer.keys.head + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty) + handler.updateResourceRequests() + handler.getPendingAllocate.size should be (1) + } + + test("process same completed container multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val statuses = Seq(container1, container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.processCompletedContainers(statuses) + handler.getNumExecutorsRunning should be (0) + + } + test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() @@ -272,7 +316,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (2) From c5c8b544047a83cb6128a20d31f1d943a15f9260 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 13:39:45 +0200 Subject: [PATCH 255/593] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20756 from viirya/SPARK-23593. --- .../expressions/objects/objects.scala | 47 ++++++++++++++++- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 +++++++++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a455c1c821a26..20c4f4c7324fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1410,8 +1410,47 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + if (paramVal == null) { + throw new NullPointerException("The parameter value for setters in " + + "`InitializeJavaBean` can not be null") + } else { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1424,6 +1463,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} + |if (${fieldGen.isNull}) { + | throw new NullPointerException("The parameter value for setters in " + + | "`InitializeJavaBean` can not be null"); + |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 9bfe2916b0820..44fecd602e854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -128,6 +128,50 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("Can not pass in null into setters in InitializeJavaBean") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + intercept[NullPointerException] { + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + intercept[NullPointerException] { + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -278,3 +322,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 1822ecda51cc9e14bb18050e0b8c270fee47ced7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 5 Apr 2018 13:47:06 +0200 Subject: [PATCH 256/593] [SPARK-23582][SQL] StaticInvoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `StaticInvoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20753 from kiszk/SPARK-23582. --- .../expressions/objects/objects.scala | 14 +++- .../expressions/ObjectExpressionsSuite.scala | 66 ++++++++++++++++++- 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 20c4f4c7324fd..9ca0b6137679e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. @@ -217,12 +218,21 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + + override def eval(input: InternalRow): Any = { + invoke(null, method, arguments, input, dataType) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 44fecd602e854..eb89e01b5ff9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -28,9 +30,11 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -93,6 +97,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23582: StaticInvoke should support interpreted execution") { + Seq((classOf[java.lang.Boolean], "true", true), + (classOf[java.lang.Byte], "1", 1.toByte), + (classOf[java.lang.Short], "257", 257.toShort), + (classOf[java.lang.Integer], "12345", 12345), + (classOf[java.lang.Long], "12345678", 12345678.toLong), + (classOf[java.lang.Float], "12.34", 12.34.toFloat), + (classOf[java.lang.Double], "1.2345678", 1.2345678) + ).foreach { case (cls, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))), + expected, InternalRow.fromSeq(Seq(arg))) + } + + // Return null when null argument is passed with propagateNull = true + val stringCls = classOf[java.lang.String] + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true), + null, InternalRow.fromSeq(Seq(null))) + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false), + "null", InternalRow.fromSeq(Seq(null))) + + // test no argument + val clCls = classOf[java.lang.ClassLoader] + checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil), + ClassLoader.getSystemClassLoader, InternalRow.empty) + // test more than one argument + val intCls = classOf[java.lang.Integer] + checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare", + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))), + 0, InternalRow.fromSeq(Seq(7, 7))) + + Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]), + new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))), + (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]), + new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))), + (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]), + "abc", UTF8String.fromString("abc")), + (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]), + BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))), + (Decimal.getClass, DecimalType.SYSTEM_DEFAULT, + "apply", ObjectType(classOf[java.math.BigInteger]), + new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))), + (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]), + Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))), + (classOf[UnsafeArrayData], ArrayType(IntegerType, false), + "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), + Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), + (DateTimeUtils.getClass, ObjectType(classOf[Date]), + "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), + "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) + ).foreach { case (cls, dataType, methodName, argType, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, + Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) + } + } + test("SPARK-23583: Invoke should support interpreted execution") { val targetObject = new InvokeTargetClass val funcClass = classOf[InvokeTargetClass] From b2329fb1fcdc0e93c4bdc39d574cde7328ef6094 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 5 Apr 2018 13:57:41 +0200 Subject: [PATCH 257/593] Revert "[SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression" This reverts commit c5c8b544047a83cb6128a20d31f1d943a15f9260. --- .../expressions/objects/objects.scala | 47 +---------------- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 ------------------- 3 files changed, 5 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9ca0b6137679e..3fa91bd36bb60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,47 +1420,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - private lazy val resolvedSetters = { - assert(beanInstance.dataType.isInstanceOf[ObjectType]) - - val ObjectType(beanClass) = beanInstance.dataType - setters.map { - case (name, expr) => - // Looking for known type mapping. - // But also looking for general `Object`-type parameter for generic methods. - val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) - val methods = paramTypes.flatMap { fieldClass => - try { - Some(beanClass.getDeclaredMethod(name, fieldClass)) - } catch { - case e: NoSuchMethodException => None - } - } - if (methods.isEmpty) { - throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + - "in any enclosing class nor any supertype") - } - methods.head -> expr - } - } - - override def eval(input: InternalRow): Any = { - val instance = beanInstance.eval(input) - if (instance != null) { - val bean = instance.asInstanceOf[Object] - resolvedSetters.foreach { - case (setter, expr) => - val paramVal = expr.eval(input) - if (paramVal == null) { - throw new NullPointerException("The parameter value for setters in " + - "`InitializeJavaBean` can not be null") - } else { - setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) - } - } - } - instance - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1473,10 +1434,6 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |if (${fieldGen.isNull}) { - | throw new NullPointerException("The parameter value for setters in " + - | "`InitializeJavaBean` can not be null"); - |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,8 +55,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -112,14 +111,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (!errMsg.contains(expectedErrMsg)) { + if (errMsg != expectedErrMsg) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index eb89e01b5ff9d..1d59b20077fa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,50 +192,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } - test("SPARK-23593: InitializeJavaBean should support interpreted execution") { - val list = new java.util.LinkedList[Int]() - list.add(1) - - val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), - Map("add" -> Literal(1))) - checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) - - val initializeWithNonexistingMethod = InitializeJavaBean( - Literal.fromObject(new java.util.LinkedList[Int]), - Map("nonexisting" -> Literal(1))) - checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), - """A method named "nonexisting" is not declared in any enclosing class """ + - "nor any supertype") - - val initializeWithWrongParamType = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setX" -> Literal("1"))) - intercept[Exception] { - evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) - }.getMessage.contains( - """A method named "setX" is not declared in any enclosing class """ + - "nor any supertype") - } - - test("Can not pass in null into setters in InitializeJavaBean") { - val initializeBean = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal(null))) - intercept[NullPointerException] { - evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - intercept[NullPointerException] { - evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - - val initializeBean2 = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal("string"))) - evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) - evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) - } - test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -386,11 +342,3 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - -class TestBean extends Serializable { - private var x: Int = 0 - - def setX(i: Int): Unit = x = i - def setNonPrimitive(i: AnyRef): Unit = - assert(i != null, "this setter should not be called with null.") -} From d9ca1c906bd0571802f2297c36b407e660fcdb64 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 20:43:05 +0200 Subject: [PATCH 258/593] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20985 from viirya/SPARK-23593-2. --- .../expressions/objects/objects.scala | 45 +++++++++++++++-- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 48 +++++++++++++++++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 3fa91bd36bb60..9252425f86473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,8 +1420,45 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + // We don't call setter if input value is null. + if (paramVal != null) { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1434,7 +1471,9 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |$javaBeanInstance.$setterMethod(${fieldGen.value}); + |if (!${fieldGen.isNull}) { + | $javaBeanInstance.$setterMethod(${fieldGen.value}); + |} """.stripMargin } val initializeCode = ctx.splitExpressionsWithCurrentInputs( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1d59b20077fa9..b1bc67dfac1b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,6 +192,46 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("InitializeJavaBean doesn't call setters if input in null") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -342,3 +382,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 4807d381bb113a5c61e6dad88202f23a8b6dd141 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:13:59 +0800 Subject: [PATCH 259/593] [SPARK-10399][CORE][SQL] Introduce multiple MemoryBlocks to choose several types of memory block ## What changes were proposed in this pull request? This PR allows us to use one of several types of `MemoryBlock`, such as byte array, int array, long array, or `java.nio.DirectByteBuffer`. To use `java.nio.DirectByteBuffer` allows to have off heap memory which is automatically deallocated by JVM. `MemoryBlock` class has primitive accessors like `Platform.getInt()`, `Platform.putint()`, or `Platform.copyMemory()`. This PR uses `MemoryBlock` for `OffHeapColumnVector`, `UTF8String`, and other places. This PR can improve performance of operations involving memory accesses (e.g. `UTF8String.trim`) by 1.8x. For now, this PR does not use `MemoryBlock` for `BufferHolder` based on cloud-fan's [suggestion](https://github.com/apache/spark/pull/11494#issuecomment-309694290). Since this PR is a successor of #11494, close #11494. Many codes were ported from #11494. Many efforts were put here. **I think this PR should credit to yzotov.** This PR can achieve **1.1-1.4x performance improvements** for operations in `UTF8String` or `Murmur3_x86_32`. Other operations are almost comparable performances. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 526 / 536 0.0 131399881.5 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 525 / 552 1022.6 1.0 1.0X substring 414 / 423 1298.0 0.8 1.3X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 474 / 488 0.0 118552232.0 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 476 / 480 1127.3 0.9 1.0X substring 287 / 291 1869.9 0.5 1.7X ``` Benchmark program ``` test("benchmark Murmur3_x86_32") { val length = 8192 * 32768 + 31 val seed = 42L val iters = 1 << 2 val random = new Random(seed) val arrays = Array.fill[MemoryBlock](numArrays) { val bytes = new Array[Byte](length) random.nextBytes(bytes) new ByteArrayMemoryBlock(bytes, Platform.BYTE_ARRAY_OFFSET, length) } val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays, minNumIters = 20) benchmark.addCase("HiveHasher") { _: Int => var sum = 0L for (_ <- 0L until iters) { sum += HiveHasher.hashUnsafeBytesBlock( arrays(i), Platform.BYTE_ARRAY_OFFSET, length) } } benchmark.run() } test("benchmark UTF8String") { val N = 512 * 1024 * 1024 val iters = 2 val benchmark = new Benchmark("UTF8String benchmark", N, minNumIters = 20) val str0 = new java.io.StringWriter() { { for (i <- 0 until N) { write(" ") } } }.toString val s0 = UTF8String.fromString(str0) benchmark.addCase("hashCode") { _: Int => var h: Int = 0 for (_ <- 0L until iters) { h += s0.hashCode } } benchmark.addCase("substring") { _: Int => var s: UTF8String = null for (_ <- 0L until iters) { s = s0.substring(N / 2 - 5, N / 2 + 5) } } benchmark.run() } ``` I run [this benchmark program](https://gist.github.com/kiszk/94f75b506c93a663bbbc372ffe8f05de) using [the commit](https://github.com/apache/spark/pull/19222/commits/ee5a79861c18725fb1cd9b518cdfd2489c05b81d6). I got the following results: ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Memory access benchmarks: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ByteArrayMemoryBlock get/putInt() 220 / 221 609.3 1.6 1.0X Platform get/putInt(byte[]) 220 / 236 610.9 1.6 1.0X Platform get/putInt(Object) 492 / 494 272.8 3.7 0.4X OnHeapMemoryBlock get/putLong() 322 / 323 416.5 2.4 0.7X long[] 221 / 221 608.0 1.6 1.0X Platform get/putLong(long[]) 321 / 321 418.7 2.4 0.7X Platform get/putLong(Object) 561 / 563 239.2 4.2 0.4X ``` I also run [this benchmark program](https://gist.github.com/kiszk/5fdb4e03733a5d110421177e289d1fb5) for comparing performance of `Platform.copyMemory()`. ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Platform copyMemory: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Object to Object 1961 / 1967 8.6 116.9 1.0X System.arraycopy Object to Object 1917 / 1921 8.8 114.3 1.0X byte array to byte array 1961 / 1968 8.6 116.9 1.0X System.arraycopy byte array to byte array 1909 / 1937 8.8 113.8 1.0X int array to int array 1921 / 1990 8.7 114.5 1.0X double array to double array 1918 / 1923 8.7 114.3 1.0X Object to byte array 1961 / 1967 8.6 116.9 1.0X Object to short array 1965 / 1972 8.5 117.1 1.0X Object to int array 1910 / 1915 8.8 113.9 1.0X Object to float array 1971 / 1978 8.5 117.5 1.0X Object to double array 1919 / 1944 8.7 114.4 1.0X byte array to Object 1959 / 1967 8.6 116.8 1.0X int array to Object 1961 / 1970 8.6 116.9 1.0X double array to Object 1917 / 1924 8.8 114.3 1.0X ``` These results show three facts: 1. According to the second/third or sixth/seventh results in the first experiment, if we use `Platform.get/putInt(Object)`, we achieve more than 2x worse performance than `Platform.get/putInt(byte[])` with concrete type (i.e. `byte[]`). 2. According to the second/third or fourth/fifth/sixth results in the first experiment, the fastest way to access an array element on Java heap is `array[]`. **Cons of `array[]` is that it is not possible to support unaligned-8byte access.** 3. According to the first/second/third or fourth/sixth/seventh results in the first experiment, `getInt()/putInt() or getLong()/putLong()` in subclasses of `MemoryBlock` can achieve comparable performance to `Platform.get/putInt()` or `Platform.get/putLong()` with concrete type (second or sixth result). There is no overhead regarding virtual call. 4. According to results in the second experiment, for `Platform.copy()`, to pass `Object` can achieve the same performance as to pass any type of primitive array as source or destination. 5. According to second/fourth results in the second experiment, `Platform.copy()` can achieve the same performance as `System.arrayCopy`. **It would be good to use `Platform.copy()` since `Platform.copy()` can take any types for src and dst.** We are incrementally replace `Platform.get/putXXX` with `MemoryBlock.get/putXXX`. This is because we have two advantages. 1) Achieve better performance due to having a concrete type for an array. 2) Use simple OO design instead of passing `Object` It is easy to use `MemoryBlock` in `InternalRow`, `BufferHolder`, `TaskMemoryManager`, and others that are already abstracted. It is not easy to use `MemoryBlock` in utility classes related to hashing or others. Other candidates are - UnsafeRow, UnsafeArrayData, UnsafeMapData, SpecificUnsafeRowJoiner - UTF8StringBuffer - BufferHolder - TaskMemoryManager - OnHeapColumnVector - BytesToBytesMap - CachedBatch - classes for hash - others. ## How was this patch tested? Added `UnsafeMemoryAllocator` Author: Kazuaki Ishizaki Closes #19222 from kiszk/SPARK-10399. --- .../sql/catalyst/expressions/HiveHasher.java | 12 +- .../org/apache/spark/unsafe/Platform.java | 2 +- .../spark/unsafe/array/ByteArrayMethods.java | 13 +- .../apache/spark/unsafe/array/LongArray.java | 17 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 45 +++-- .../unsafe/memory/ByteArrayMemoryBlock.java | 128 +++++++++++++ .../unsafe/memory/HeapMemoryAllocator.java | 19 +- .../spark/unsafe/memory/MemoryAllocator.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 157 ++++++++++++++-- .../spark/unsafe/memory/MemoryLocation.java | 54 ------ .../unsafe/memory/OffHeapMemoryBlock.java | 105 +++++++++++ .../unsafe/memory/OnHeapMemoryBlock.java | 132 +++++++++++++ .../unsafe/memory/UnsafeMemoryAllocator.java | 21 ++- .../apache/spark/unsafe/types/UTF8String.java | 148 +++++++-------- .../spark/unsafe/PlatformUtilSuite.java | 4 +- .../spark/unsafe/array/LongArraySuite.java | 5 +- .../unsafe/hash/Murmur3_x86_32Suite.java | 18 ++ .../spark/unsafe/memory/MemoryBlockSuite.java | 175 ++++++++++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 29 +-- .../spark/memory/TaskMemoryManager.java | 22 +-- .../shuffle/sort/ShuffleInMemorySorter.java | 14 +- .../shuffle/sort/ShuffleSortDataFormat.java | 11 +- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../unsafe/sort/UnsafeInMemorySorter.java | 13 +- .../spark/memory/TaskMemoryManagerSuite.java | 2 +- .../util/collection/ExternalSorterSuite.scala | 7 +- .../unsafe/sort/RadixSortSuite.scala | 10 +- .../spark/ml/feature/FeatureHasher.scala | 5 +- .../spark/mllib/feature/HashingTF.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 4 +- .../spark/sql/catalyst/expressions/XXH64.java | 46 +++-- .../spark/sql/catalyst/expressions/hash.scala | 39 ++-- .../catalyst/expressions/HiveHasherSuite.java | 20 +- .../sql/catalyst/expressions/XXH64Suite.java | 18 +- .../vectorized/OffHeapColumnVector.java | 3 +- .../sql/vectorized/ArrowColumnVector.java | 6 +- .../execution/benchmark/SortBenchmark.scala | 16 +- .../sql/execution/python/RowQueueSuite.scala | 4 +- 39 files changed, 1002 insertions(+), 334 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java create mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 73577437ac506..5d905943a3aa7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -38,12 +39,17 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + public static int hashUnsafeBytesBlock(MemoryBlock mb) { + long lengthInBytes = mb.size(); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (int i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) Platform.getByte(base, offset + i); + for (long i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) mb.getByte(i); } return result; } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..54dcadf3a7754 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index a6b1f7a16d605..c334c9651cf6b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -48,6 +49,16 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); + /** + * MemoryBlock equality check for MemoryBlocks. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEqualsBlock( + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, + rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); + } + /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise @@ -56,7 +67,7 @@ public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if stars align and we can get both offsets to be aligned + // check if starts align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 2cd39bd60c2ac..b74d2de0691d5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -33,16 +32,12 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; - private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; - this.baseObj = memory.getBaseObject(); - this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -51,11 +46,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return baseObj; + return memory.getBaseObject(); } public long getBaseOffset() { - return baseOffset; + return memory.getBaseOffset(); } /** @@ -69,8 +64,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + for (long off = 0; off < length * WIDTH; off += WIDTH) { + memory.putLong(off, 0); } } @@ -80,7 +75,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + memory.putLong(index * WIDTH, value); } /** @@ -89,6 +84,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return memory.getLong(index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index d239de6083ad0..f372b19fac119 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,9 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.Platform; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.memory.MemoryBlock; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -49,49 +51,66 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + int h1 = hashBytesByIntBlock(base, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(base, offset + i); + int halfWord = base.getByte(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { // This is compatible with original and another implementations. // Use this method for new components after Spark 2.3. - assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthInBytes = Ints.checkedCast(base.size()); + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + k1 ^= (base.getByte(i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + private static int hashBytesByIntBlock(MemoryBlock base, int seed) { + long lengthInBytes = base.size(); assert (lengthInBytes % 4 == 0); int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); + for (long i = 0; i < lengthInBytes; i += 4) { + int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 0000000000000..99a9868a49a79 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a byte array on Java heap. + */ +public final class ByteArrayMemoryBlock extends MemoryBlock { + + private final byte[] array; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); + } + + public ByteArrayMemoryBlock(long length) { + this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new ByteArrayMemoryBlock(array, this.offset + offset, size); + } + + public byte[] getByteArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the byte array. + */ + public static ByteArrayMemoryBlock fromArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; + } + + @Override + public final void putByte(long offset, byte value) { + array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 2733760dd19ef..acf28fd7ee59b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -58,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -70,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -79,12 +79,13 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj != null) : + assert(memory instanceof OnHeapMemoryBlock); + assert (memory.getBaseObject() != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -94,12 +95,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. - long[] array = (long[]) memory.obj; - memory.setObjAndOffset(null, 0); + long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); + memory.resetObjAndOffset(); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 7b588681d9790..38315fb97b46a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index c333857358d30..b086941108522 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + * A representation of a consecutive memory block in Spark. It defines the common interfaces + * for memory accessing and mutating. */ -public class MemoryBlock extends MemoryLocation { - +public abstract class MemoryBlock { /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,38 +45,163 @@ public class MemoryBlock extends MemoryLocation { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - private final long length; + @Nullable + protected Object obj; + + protected long offset; + + protected long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. + * TaskMemoryManager. This field can be updated using setPageNumber method so that + * this can be modified by the TaskMemoryManager, which lives in a different package. */ - public int pageNumber = NO_PAGE_NUMBER; + private int pageNumber = NO_PAGE_NUMBER; - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); + protected MemoryBlock(@Nullable Object obj, long offset, long length) { + if (offset < 0 || length < 0) { + throw new IllegalArgumentException( + "Length " + length + " and offset " + offset + "must be non-negative"); + } + this.obj = obj; + this.offset = offset; this.length = length; } + protected MemoryBlock() { + this(null, 0, 0); + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } + + public void resetObjAndOffset() { + this.obj = null; + this.offset = 0; + } + /** * Returns the size of the memory block. */ - public long size() { + public final long size() { return length; } - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + public final void setPageNumber(int pageNum) { + pageNumber = pageNum; + } + + public final int getPageNumber() { + return pageNumber; } /** * Fills the memory block with the specified byte value. */ - public void fill(byte value) { + public final void fill(byte value) { Platform.setMemory(obj, offset, length, value); } + + /** + * Instantiate MemoryBlock for given object type with new offset + */ + public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + MemoryBlock mb = null; + if (obj instanceof byte[]) { + byte[] array = (byte[])obj; + mb = new ByteArrayMemoryBlock(array, offset, length); + } else if (obj instanceof long[]) { + long[] array = (long[])obj; + mb = new OnHeapMemoryBlock(array, offset, length); + } else if (obj == null) { + // we assume that to pass null pointer means off-heap + mb = new OffHeapMemoryBlock(offset, length); + } else { + throw new UnsupportedOperationException( + "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); + } + return mb; + } + + /** + * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative + * offset from the original offset. The data is not copied. + * If parameters are invalid, an exception is thrown. + */ + public abstract MemoryBlock subBlock(long offset, long size); + + protected void checkSubBlockRange(long offset, long size) { + if (offset < 0 || size < 0) { + throw new ArrayIndexOutOfBoundsException( + "Size " + size + " and offset " + offset + " must be non-negative"); + } + if (offset + size > length) { + throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + + offset + " should not be larger than the length " + length + " in the MemoryBlock"); + } + } + + /** + * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal + * memory access, throw an exception, or etc. + * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as + * JVM object header. The offset is 0-based and is expected as an logical offset in the memory + * block. + */ + public abstract int getInt(long offset); + + public abstract void putInt(long offset, int value); + + public abstract boolean getBoolean(long offset); + + public abstract void putBoolean(long offset, boolean value); + + public abstract byte getByte(long offset); + + public abstract void putByte(long offset, byte value); + + public abstract short getShort(long offset); + + public abstract void putShort(long offset, short value); + + public abstract long getLong(long offset); + + public abstract void putLong(long offset, long value); + + public abstract float getFloat(long offset); + + public abstract void putFloat(long offset, float value); + + public abstract double getDouble(long offset); + + public abstract void putDouble(long offset, double value); + + public static final void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, + dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); + } + + public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { + assert(length <= src.length && length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), + dst.getBaseObject(), dst.getBaseOffset(), length); + } + + public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); + } + + public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java deleted file mode 100644 index 74ebc87dc978c..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import javax.annotation.Nullable; - -/** - * A memory location. Tracked either by a memory address (with off-heap allocation), - * or by an offset from a JVM object (in-heap allocation). - */ -public class MemoryLocation { - - @Nullable - Object obj; - - long offset; - - public MemoryLocation(@Nullable Object obj, long offset) { - this.obj = obj; - this.offset = offset; - } - - public MemoryLocation() { - this(null, 0); - } - - public void setObjAndOffset(Object newObj, long newOffset) { - this.obj = newObj; - this.offset = newOffset; - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 0000000000000..f90f62bf21dcb --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock extends MemoryBlock { + static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + + public OffHeapMemoryBlock(long address, long size) { + super(null, address, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OffHeapMemoryBlock(this.offset + offset, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(null, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(null, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(null, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(null, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(null, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(null, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(null, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(null, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(null, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(null, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(null, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(null, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(null, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(null, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java new file mode 100644 index 0000000000000..12f67c7bd593e --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a long array on Java heap. + */ +public final class OnHeapMemoryBlock extends MemoryBlock { + + private final long[] array; + + public OnHeapMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); + } + + public OnHeapMemoryBlock(long size) { + this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OnHeapMemoryBlock(array, this.offset + offset, size); + } + + public long[] getLongArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static OnHeapMemoryBlock fromArray(final long[] array) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + public static OnHeapMemoryBlock fromArray(final long[] array, long size) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(array, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(array, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 4368fb615ba1e..5310bdf2779a9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - MemoryBlock memory = new MemoryBlock(null, address, size); + OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,22 +36,25 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert(memory instanceof OffHeapMemoryBlock) : + "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; + if (memory == OffHeapMemoryBlock.NULL) return; + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.offset = 0; + memory.resetObjAndOffset(); // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5d468aed42337..e9b3d9b045af5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,9 +30,12 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -50,12 +53,13 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; - private long offset; + private MemoryBlock base; + // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int private int numBytes; - public Object getBaseObject() { return base; } - public long getBaseOffset() { return offset; } + public MemoryBlock getMemoryBlock() { return base; } + public Object getBaseObject() { return base.getBaseObject(); } + public long getBaseOffset() { return base.getBaseOffset(); } /** * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which @@ -108,7 +112,8 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); } else { return null; } @@ -121,19 +126,13 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); } else { return null; } } - /** - * Creates an UTF8String from given address (base and offset) and length. - */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { - return new UTF8String(base, offset, numBytes); - } - /** * Creates an UTF8String from String. */ @@ -150,16 +149,13 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + public UTF8String(MemoryBlock base) { this.base = base; - this.offset = offset; - this.numBytes = numBytes; + this.numBytes = Ints.checkedCast(base.size()); } // for serialization - public UTF8String() { - this(null, 0, 0); - } + public UTF8String() {} /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -167,7 +163,7 @@ public UTF8String() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - Platform.copyMemory(base, offset, target, targetOffset, numBytes); + base.writeTo(0, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -187,8 +183,9 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = (byte[]) base; + long offset = base.getBaseOffset(); + if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -255,12 +252,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); + p = (long) base.getInt(0); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -269,12 +266,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; + p = ((long) base.getInt(0)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -289,12 +286,13 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + long offset = base.getBaseOffset(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -324,7 +322,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -365,14 +363,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return Platform.getByte(base, offset + i); + return base.getByte(i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -499,8 +497,7 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } lastComma = i; @@ -508,8 +505,7 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } return 0; @@ -524,7 +520,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -671,8 +667,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - copyMemory(this.base, this.offset + i, result, - BYTE_ARRAY_OFFSET + result.length - i - len, len); + base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -686,7 +681,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -723,7 +718,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -739,7 +734,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start += 1; @@ -753,7 +748,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start -= 1; @@ -786,7 +781,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -806,7 +801,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -829,15 +824,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -865,13 +860,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -896,8 +891,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -936,8 +931,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -945,8 +940,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - copyMemory( - separator.base, separator.offset, + separator.base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1220,7 +1215,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1228,11 +1223,10 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; + MemoryBlock rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); + long left = base.getLong(i); + long right = rbase.getLong(i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1243,7 +1237,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); + int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); if (res != 0) { return res; } @@ -1262,7 +1256,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); } else { return false; } @@ -1318,8 +1312,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, - s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, + i_bytes, num_bytes_j)) ? 0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1334,7 +1328,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); + return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); } /** @@ -1397,10 +1391,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + byte[] bytes = new byte[numBytes]; + in.readFully(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } @Override @@ -1412,10 +1406,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - this.offset = BYTE_ARRAY_OFFSET; - this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + numBytes = in.readInt(); + byte[] bytes = new byte[numBytes]; + in.read(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..583a148b3845d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f3..8c2e98c2bfc54 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,14 +20,13 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; public class LongArraySuite { @Test public void basicTest() { - long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 6348a73bf3895..d7ed005db1891 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,6 +70,24 @@ public void testKnownBytesInputs() { Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } + @Test + public void testKnownWordsInputs() { + byte[] bytes = new byte[16]; + long offset = Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < 16; i++) { + bytes[i] = 0; + } + Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = -1; + } + Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = (byte)i; + } + Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java new file mode 100644 index 0000000000000..47f05c928f2e5 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteOrder; + +import static org.hamcrest.core.StringContains.containsString; + +public class MemoryBlockSuite { + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private void check(MemoryBlock memory, Object obj, long offset, int length) { + memory.setPageNumber(1); + memory.fill((byte)-1); + memory.putBoolean(0, true); + memory.putByte(1, (byte)127); + memory.putShort(2, (short)257); + memory.putInt(4, 0x20000002); + memory.putLong(8, 0x1234567089ABCDEFL); + memory.putFloat(16, 1.0F); + memory.putLong(20, 0x1234567089ABCDEFL); + memory.putDouble(28, 2.0); + MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); + int[] a = new int[2]; + a[0] = 0x12345678; + a[1] = 0x13579BDF; + memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); + byte[] b = new byte[8]; + memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); + + Assert.assertEquals(obj, memory.getBaseObject()); + Assert.assertEquals(offset, memory.getBaseOffset()); + Assert.assertEquals(length, memory.size()); + Assert.assertEquals(1, memory.getPageNumber()); + Assert.assertEquals(true, memory.getBoolean(0)); + Assert.assertEquals((byte)127, memory.getByte(1 )); + Assert.assertEquals((short)257, memory.getShort(2)); + Assert.assertEquals(0x20000002, memory.getInt(4)); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); + Assert.assertEquals(1.0F, memory.getFloat(16), 0); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); + Assert.assertEquals(2.0, memory.getDouble(28), 0); + Assert.assertEquals(true, memory.getBoolean(36)); + Assert.assertEquals((byte)127, memory.getByte(37 )); + Assert.assertEquals((short)257, memory.getShort(38)); + Assert.assertEquals(a[0], memory.getInt(40)); + Assert.assertEquals(a[1], memory.getInt(44)); + if (bigEndianPlatform) { + Assert.assertEquals(a[0], + ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | + ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | + ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); + } else { + Assert.assertEquals(a[0], + ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | + ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | + ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); + } + for (int i = 48; i < memory.size(); i++) { + Assert.assertEquals((byte) -1, memory.getByte(i)); + } + + assert(memory.subBlock(0, memory.size()) == memory); + + try { + memory.subBlock(-8, 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, -8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, length + 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(8, length - 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(length + 8, 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + } + + @Test + public void ByteArrayMemoryBlockTest() { + byte[] obj = new byte[56]; + long offset = Platform.BYTE_ARRAY_OFFSET; + int length = obj.length; + + MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = ByteArrayMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new byte[112]; + memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OnHeapMemoryBlockTest() { + long[] obj = new long[7]; + long offset = Platform.LONG_ARRAY_OFFSET; + int length = obj.length * 8; + + MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = OnHeapMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new long[14]; + memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OffHeapArrayMemoryBlockTest() { + MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); + MemoryBlock memory = memoryAllocator.allocate(56); + Object obj = memory.getBaseObject(); + long offset = memory.getBaseOffset(); + int length = 56; + + check(memory, obj, offset, length); + + long address = Platform.allocateMemory(112); + memory = new OffHeapMemoryBlock(address, length); + obj = memory.getBaseObject(); + offset = memory.getBaseOffset(); + check(memory, obj, offset, length); + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7c34d419574ef..bad908fcaf136 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -26,6 +26,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; import static org.junit.Assert.*; @@ -519,7 +522,8 @@ public void writeToOutputStreamUnderflow() throws IOException { final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + new UTF8String( + new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); @@ -534,7 +538,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -565,7 +569,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -592,26 +596,25 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamIntArray() throws IOException { + public void writeToOutputStreamLongArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(12, length); + assertEquals(16, length); - final int ints = length / 4; - final int[] array = new int[ints]; + final int longs = length / 8; + final long[] array = new long[longs]; - for (int i = 0; i < ints; ++i) { - array[i] = buffer.getInt(); + for (int i = 0; i < longs; ++i) { + array[i] = buffer.getLong(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - fromAddress(array, Platform.INT_ARRAY_OFFSET, length) - .writeTo(outputStream); - assertEquals("大千世界", outputStream.toString("UTF-8")); + new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); + assertEquals("3千大千世界", outputStream.toString("UTF-8")); } @Test diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..8651a639c07f7 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.pageNumber = pageNumber; + page.setPageNumber(pageNumber); pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; + assert(allocatedPages.get(page.getPageNumber())); + pageTable[page.getPageNumber()] = null; synchronized (this) { - allocatedPages.clear(page.pageNumber); + allocatedPages.clear(page.getPageNumber()); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); } long pageSize = page.size(); // Clear the page number before passing the block to the MemoryAllocator's free(). // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); } @VisibleForTesting @@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index dc36809d8911f..8f49859746b89 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,7 +20,6 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -105,13 +104,7 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L - ); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -180,10 +173,7 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 717bdd79d47ef..254449e95443e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,13 +60,8 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - Platform.copyMemory( - src.getBaseObject(), - src.getBaseOffset() + srcPos * 8L, - dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 8L, - length * 8L - ); + MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, + dst.memoryBlock(),dstPos * 8L,length * 8L); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 66118f454159b..4fc19b1721518 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.pageNumber != + if (!loaded || page.getPageNumber() != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index b3c27d83da172..20a7a8b267438 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,7 +26,6 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -216,12 +215,7 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -348,10 +342,7 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index a0664b30d6cc2..d7d2d0b012bd3 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() { final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); final MemoryBlock dataPage = manager.allocatePage(256, c); c.freePage(dataPage); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 47173b89e91e2..3e56db5ea116a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(MemoryBlock.fromLongArray(ref)) - val tmp = new Array[Long](size/2) - val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) + val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..ddf3740e76a7a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index c78f61ac3ef71..d67e4819b161a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -243,8 +243,7 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => - val utf8 = UTF8String.fromString(s) - hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 8935c8496cdbb..7b73b286fb91c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d18542b188f71..8546c28335536 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,6 +27,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -230,7 +231,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 71c086029cc5b..29a1411241cf6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,6 +37,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -414,7 +415,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index f37ef83ad92b4..883748932ad33 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off /** @@ -71,13 +74,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWords(Object base, long offset, int length) { - return hashUnsafeWords(base, offset, length, seed); + public long hashUnsafeWordsBlock(MemoryBlock mb) { + return hashUnsafeWordsBlock(mb, seed); } - public static long hashUnsafeWords(Object base, long offset, int length, long seed) { - assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWords(base, offset, length, seed); + public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { + assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWordsBlock(mb, seed); return fmix(hash); } @@ -85,26 +88,32 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWords(base, offset, length, seed); + long hash = hashBytesByWordsBlock(mb, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } return fmix(hash); } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + } + private static long fmix(long hash) { hash ^= hash >>> 33; hash *= PRIME64_2; @@ -114,30 +123,31 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWords(Object base, long offset, int length, long seed) { - long end = offset + length; + private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); long hash; if (length >= 32) { - long limit = end - 32; + long limit = length - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 += mb.getLong(offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 += mb.getLong(offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 += mb.getLong(offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 += mb.getLong(offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -178,9 +188,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long hash += length; - long limit = end - 8; + long limit = length - 8; while (offset <= limit) { - long k1 = Platform.getLong(base, offset); + long k1 = mb.getLong(offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b702422ed7a1d..b76b64ab5096f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -360,10 +361,8 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" } protected def genHashForMap( @@ -465,6 +464,8 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long + /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -490,8 +491,7 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => - hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) case array: ArrayData => val elementType = dataType match { @@ -578,9 +578,15 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } + + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) + } } /** @@ -605,9 +611,14 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } + + override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { + XXH64.hashUnsafeBytesBlock(base, seed) + } } /** @@ -714,10 +725,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" } override protected def genHashForArray( @@ -805,10 +814,14 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) + private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index b67c6f3e6e85e..8ffc1d7c24d61 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -53,7 +55,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -89,13 +91,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. @@ -112,13 +114,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 1baee91b3439c..cd8bce623c5df 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,6 +24,8 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -142,13 +144,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. @@ -165,13 +167,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 754c26579ff08..4733f36174f42 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return UTF8String.fromAddress(null, data + rowId, count); + return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f8e37e995a17f..227a16f7e69e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -377,9 +378,10 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return UTF8String.fromAddress(null, + return new UTF8String(new OffHeapMemoryBlock( stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); + stringResult.end - stringResult.start + )); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a3ff9d9..470b93efd1974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33cf906c5..25ee95daa034c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -22,13 +22,13 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite { test("in-memory queue") { - val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val page = new OnHeapMemoryBlock((1<<10) * 8L) val queue = new InMemoryRowQueue(page, 1) { override def close() {} } From f2ac0879561cde63ed4eb759f5efa0a5ce393a22 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Thu, 5 Apr 2018 19:55:42 -0700 Subject: [PATCH 260/593] [SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns ## What changes were proposed in this pull request? `handleInvalid` Param was forwarded to the VectorAssembler used by RFormula. ## How was this patch tested? added a test and ran all tests for RFormula and VectorAssembler Author: Yogesh Garg Closes #20970 from yogeshg/spark_23562. --- .../apache/spark/ml/feature/RFormula.scala | 1 + .../spark/ml/feature/RFormulaSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 22e7b8bbf1ff5..e214765e3307f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + .setHandleInvalid($(handleInvalid)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 27d570f0b68ad..a250331efeb1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { assert(features.toArray === a +: b.toArray) } } + + test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") { + val d1 = Seq( + (1001L, "a"), + (1002L, "b")).toDF("id1", "c1") + val d2 = Seq[(java.lang.Long, String)]( + (20001L, "x"), + (20002L, "y"), + (null, null)).toDF("id2", "c2") + val dataset = d1.crossJoin(d2) + + def get_output(mode: String): DataFrame = { + val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode) + formula.fit(dataset).transform(dataset).select("features", "label") + } + + assert(intercept[SparkException](get_output("error").collect()) + .getMessage.contains("Encountered null while assembling a row")) + assert(get_output("skip").count() == 4) + assert(get_output("keep").count() == 6) + } + } From d65e531b44a388fed25d3cbf28fdce5a2d0598e6 Mon Sep 17 00:00:00 2001 From: JiahuiJiang Date: Thu, 5 Apr 2018 20:06:08 -0700 Subject: [PATCH 261/593] [SPARK-23823][SQL] Keep origin in transformExpression Fixes https://issues.apache.org/jira/browse/SPARK-23823 Keep origin for all the methods using transformExpression ## What changes were proposed in this pull request? Keep origin in transformExpression ## How was this patch tested? Manually tested that this fixes https://issues.apache.org/jira/browse/SPARK-23823 and columns have correct origins after Analyzer.analyze Author: JiahuiJiang Author: Jiahui Jiang Closes #20961 from JiahuiJiang/jj/keep-origin. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++- .../sql/catalyst/plans/QueryPlanSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ddf2cbf2ab911..64cb8c726772f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -103,7 +103,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT var changed = false @inline def transformExpression(e: Expression): Expression = { - val newE = f(e) + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } if (newE.fastEquals(e)) { e } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala new file mode 100644 index 0000000000000..27914ef5565c0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.IntegerType + +class QueryPlanSuite extends SparkFunSuite { + + test("origin remains the same after mapExpressions (SPARK-23823)") { + CurrentOrigin.setPosition(0, 0) + val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId) + val query = plans.DslLogicalPlan(plans.table("table")).select(column) + CurrentOrigin.reset() + + val mappedQuery = query mapExpressions { + case _: Expression => Literal(1) + } + + val mappedOrigin = mappedQuery.expressions.apply(0).origin + assert(mappedOrigin == Origin.apply(Some(0), Some(0))) + } + +} From 249007e37f51f00d14e596692aeac87fbc10b520 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 5 Apr 2018 20:19:25 -0700 Subject: [PATCH 262/593] [SPARK-19724][SQL] create a managed table with an existed default table should throw an exception ## What changes were proposed in this pull request? This PR is to finish https://github.com/apache/spark/pull/17272 This JIRA is a follow up work after SPARK-19583 As we discussed in that PR The following DDL for a managed table with an existed default location should throw an exception: CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... CREATE TABLE ... (PARTITIONED BY ...) Currently there are some situations which are not consist with above logic: CREATE TABLE ... (PARTITIONED BY ...) succeed with an existed default location situation: for both hive/datasource(with HiveExternalCatalog/InMemoryCatalog) CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... situation: hive table succeed with an existed default location This PR is going to make above two situations consist with the logic that it should throw an exception with an existed default location. ## How was this patch tested? unit test added Author: Gengliang Wang Closes #20886 from gengliangwang/pr-17272. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../command/createDataSourceTables.scala | 5 +- .../spark/sql/StatisticsCollectionSuite.scala | 7 ++ .../sql/execution/command/DDLSuite.scala | 67 +++++++++++++++++++ 6 files changed, 110 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2b393f30d1435..9822d669050d5 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1809,6 +1809,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 64e7ca11270b4..52ed89ef8d781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -289,6 +289,7 @@ class SessionCatalog( def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) validateName(table) val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined @@ -298,15 +299,33 @@ class SessionCatalog( makeQualifiedPath(tableDefinition.storage.locationUri.get) tableDefinition.copy( storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), - identifier = TableIdentifier(table, Some(db))) + identifier = tableIdentifier) } else { - tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + tableDefinition.copy(identifier = tableIdentifier) } requireDbExists(db) + if (!ignoreIfExists) { + validateTableLocation(newTableDefinition) + } externalCatalog.createTable(newTableDefinition, ignoreIfExists) } + def validateTableLocation(table: CatalogTable): Unit = { + // SPARK-19724: the default location of a managed table should be non-existent or empty. + if (table.tableType == CatalogTableType.MANAGED && + !conf.allowCreatingManagedTableUsingNonemptyLocation) { + val tableLocation = + new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier))) + val fs = tableLocation.getFileSystem(hadoopConf) + + if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) { + throw new AnalysisException(s"Can not create the managed table('${table.identifier}')" + + s". The associated location('${tableLocation.toString}') already exists.") + } + } + } + /** * Alter the metadata of an existing metastore table identified by `tableDefinition`. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 13f31a6b2eb93..1c8ab9c62623e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1159,6 +1159,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = + buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation") + .internal() + .doc("When this option is set to true, creating managed tables with nonempty location " + + "is allowed. Otherwise, an analysis exception is thrown. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1581,6 +1589,9 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = + getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index e9747769dfcfc..f7c3e9b019258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -167,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand( sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) - + sparkSession.sessionState.catalog.validateTableLocation(table) val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { Some(sessionState.catalog.defaultTablePath(table.identifier)) } else { @@ -181,7 +181,8 @@ case class CreateDataSourceTableAsSelectCommand( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) - sessionState.catalog.createTable(newTable, ignoreIfExists = false) + // Table location is already validated. No need to check it again during table creation. + sessionState.catalog.createTable(newTable, ignoreIfExists = true) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ed4ea0231f1a7..14a565863d66c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier @@ -26,6 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -242,6 +245,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier(table))) Seq(false, true).foreach { autoUpdate => withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { @@ -269,6 +273,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) } else { checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + // SPARK-19724: clean up the previous table location. + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4df8fbfe1c0db..4304d0b6f6b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -180,6 +180,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private val escapedIdentifier = "`(.+)`".r + private def dataSource: String = { + if (isUsingHiveMetastore) { + "HIVE" + } else { + "PARQUET" + } + } protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { @@ -365,6 +372,66 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("CTAS a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + sql("INSERT INTO tab1 VALUES (1, 'a')") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing non-empty directory") { + withTable("tab1") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + // create an empty hidden file + tableLoc.mkdir() + val hiddenGarbageFile = new File(tableLoc.getCanonicalPath, ".garbage") + hiddenGarbageFile.createNewFile() + val exMsg = "Can not create the managed table('`tab1`'). The associated location" + val exMsgWithDefaultDB = + "Can not create the managed table('`default`.`tab1`'). The associated location" + var ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + }.getMessage + if (isUsingHiveMetastore) { + assert(ex.contains(exMsgWithDefaultDB)) + } else { + assert(ex.contains(exMsg)) + } + + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], From 6ade5cbb498f6c6ea38779b97f2325d5cf5013f2 Mon Sep 17 00:00:00 2001 From: Daniel Sakuma Date: Fri, 6 Apr 2018 13:37:08 +0800 Subject: [PATCH 263/593] [MINOR][DOC] Fix some typos and grammar issues ## What changes were proposed in this pull request? Easy fix in the documentation. ## How was this patch tested? N/A Closes #20948 Author: Daniel Sakuma Closes #20928 from dsakuma/fix_typo_configuration_docs. --- docs/README.md | 2 +- docs/_plugins/include_example.rb | 2 +- docs/building-spark.md | 2 +- docs/cloud-integration.md | 4 +-- docs/configuration.md | 20 ++++++------ docs/css/pygments-default.css | 2 +- docs/graphx-programming-guide.md | 4 +-- docs/job-scheduling.md | 4 +-- docs/ml-advanced.md | 2 +- docs/ml-classification-regression.md | 6 ++-- docs/ml-collaborative-filtering.md | 2 +- docs/ml-features.md | 2 +- docs/ml-migration-guides.md | 2 +- docs/ml-tuning.md | 2 +- docs/mllib-clustering.md | 2 +- docs/mllib-collaborative-filtering.md | 4 +-- docs/mllib-data-types.md | 2 +- docs/mllib-dimensionality-reduction.md | 2 +- docs/mllib-evaluation-metrics.md | 2 +- docs/mllib-feature-extraction.md | 2 +- docs/mllib-isotonic-regression.md | 4 +-- docs/mllib-linear-methods.md | 2 +- docs/mllib-optimization.md | 4 +-- docs/monitoring.md | 4 +-- docs/quick-start.md | 6 ++-- docs/rdd-programming-guide.md | 2 +- docs/running-on-kubernetes.md | 4 +-- docs/running-on-mesos.md | 12 +++---- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/spark-standalone.md | 2 +- docs/sparkr.md | 6 ++-- docs/sql-programming-guide.md | 32 +++++++++---------- docs/storage-openstack-swift.md | 2 +- docs/streaming-flume-integration.md | 6 ++-- docs/streaming-kafka-0-8-integration.md | 10 +++--- docs/streaming-programming-guide.md | 26 +++++++-------- .../structured-streaming-kafka-integration.md | 2 +- .../structured-streaming-programming-guide.md | 8 ++--- docs/submitting-applications.md | 2 +- docs/tuning.md | 2 +- python/README.md | 2 +- sql/README.md | 2 +- 43 files changed, 107 insertions(+), 107 deletions(-) diff --git a/docs/README.md b/docs/README.md index 225bb1b2040de..9eac4ba35c458 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,7 +5,7 @@ here with the Spark source code. You can also find documentation specific to rel Spark at http://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the -documentation yourself. Why build it yourself? So that you have the docs that corresponds to +documentation yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control. ## Prerequisites diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ea1d438f529e..1e91f12518e0b 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -48,7 +48,7 @@ def render(context) begin code = File.open(@file).read.encode("UTF-8") rescue => e - # We need to explicitly exit on execptions here because Jekyll will silently swallow + # We need to explicitly exit on exceptions here because Jekyll will silently swallow # them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104) puts(e) puts(e.backtrace) diff --git a/docs/building-spark.md b/docs/building-spark.md index c391255a91596..0236bb05849ad 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -113,7 +113,7 @@ Note: Flume support is deprecated as of Spark 2.3.0. ## Building submodules individually -It's possible to build Spark sub-modules using the `mvn -pl` option. +It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index c150d9efc06ff..ac1c336988930 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -27,13 +27,13 @@ description: Introduction to cloud storage support in Apache Spark SPARK_VERSION All major cloud providers offer persistent data storage in *object stores*. These are not classic "POSIX" file systems. In order to store hundreds of petabytes of data without any single points of failure, -object stores replace the classic filesystem directory tree +object stores replace the classic file system directory tree with a simpler model of `object-name => data`. To enable remote access, operations on objects are usually offered as (slow) HTTP REST operations. Spark can read and write data in object stores through filesystem connectors implemented in Hadoop or provided by the infrastructure suppliers themselves. -These connectors make the object stores look *almost* like filesystems, with directories and files +These connectors make the object stores look *almost* like file systems, with directories and files and the classic operations on them such as list, delete and rename. diff --git a/docs/configuration.md b/docs/configuration.md index 2eb6a77434ea6..4d4d0c58dd07d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -558,7 +558,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1288,7 +1288,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1513,7 +1513,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1722,7 +1722,7 @@ Apart from these, the following properties are also available, and may be useful When spark.task.reaper.enabled = true, this setting specifies a timeout after which the executor JVM will kill itself if a killed task has not stopped running. The default value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose - of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + of this setting is to act as a safety-net to prevent runaway noncancellable tasks from rendering an executor unusable. @@ -1915,8 +1915,8 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1971,7 +1971,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1980,7 +1980,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -2178,7 +2178,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defalut.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. diff --git a/docs/css/pygments-default.css b/docs/css/pygments-default.css index 6247cd8396cf1..a4d583b366603 100644 --- a/docs/css/pygments-default.css +++ b/docs/css/pygments-default.css @@ -5,7 +5,7 @@ To generate this, I had to run But first I had to install pygments via easy_install pygments I had to override the conflicting bootstrap style rules by linking to -this stylesheet lower in the html than the bootstap css. +this stylesheet lower in the html than the bootstrap css. Also, I was thrown off for a while at first when I was using markdown code block inside my {% highlight scala %} ... {% endhighlight %} tags diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 5c97a248df4bc..35293348e3f3d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -491,7 +491,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts)( The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices` except that the user defined `map` function is applied to all vertices and can change the vertex property type. Because not all vertices may have a matching value in the input RDD the `map` -function takes an `Option` type. For example, we can setup a graph for PageRank by initializing +function takes an `Option` type. For example, we can set up a graph for PageRank by initializing vertex properties with their `outDegree`. @@ -969,7 +969,7 @@ A vertex is part of a triangle when it has two adjacent vertices with an edge be # Examples Suppose I want to build a graph from some text files, restrict the graph -to important relationships and users, run page-rank on the sub-graph, and +to important relationships and users, run page-rank on the subgraph, and then finally return attributes associated with the top users. I can do all of this in just a few lines with GraphX: diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index e6d881639a13b..da90342406c84 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -23,7 +23,7 @@ run tasks and store data for that application. If multiple users need to share y different options to manage allocation, depending on the cluster manager. The simplest option, available on all cluster managers, is _static partitioning_ of resources. With -this approach, each application is given a maximum amount of resources it can use, and holds onto them +this approach, each application is given a maximum amount of resources it can use and holds onto them for its whole duration. This is the approach used in Spark's [standalone](spark-standalone.html) and [YARN](running-on-yarn.html) modes, as well as the [coarse-grained Mesos mode](running-on-mesos.html#mesos-run-modes). @@ -230,7 +230,7 @@ properties: * `minShare`: Apart from an overall weight, each pool can be given a _minimum shares_ (as a number of CPU cores) that the administrator would like it to have. The fair scheduler always attempts to meet all active pools' minimum shares before redistributing extra resources according to the weights. - The `minShare` property can therefore be another way to ensure that a pool can always get up to a + The `minShare` property can, therefore, be another way to ensure that a pool can always get up to a certain number of resources (e.g. 10 cores) quickly without giving it a high priority for the rest of the cluster. By default, each pool's `minShare` is 0. diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 2747f2df7cb10..375957e92cc4c 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -77,7 +77,7 @@ Quasi-Newton methods in this case. This fallback is currently always enabled for L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively. -In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. +In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features is no more than 4096. For larger problems, use L-BFGS instead. ## Iteratively reweighted least squares (IRLS) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index ddd2f4b49ca07..d660655e193eb 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -420,7 +420,7 @@ Refer to the [R API docs](api/R/spark.svmLinear.html) for more details. [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator`. For the base classifier, it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. @@ -908,7 +908,7 @@ Refer to the [R API docs](api/R/spark.survreg.html) for more details. belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -927,7 +927,7 @@ We implement a which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns -label, features and weight. Additionally IsotonicRegression algorithm has one +label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 58f2d4b531e70..8b0f287dc39ad 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -35,7 +35,7 @@ but the ids must be within the integer value range. ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. diff --git a/docs/ml-features.md b/docs/ml-features.md index 3370eb3893272..7aed2341584fc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1174,7 +1174,7 @@ for more details on the API. ## SQLTransformer `SQLTransformer` implements the transformations which are defined by SQL statement. -Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +Currently, we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` where `"__THIS__"` represents the underlying table of the input dataset. The select clause specifies the fields, constants, and expressions to display in the output, and can be any select clause that Spark SQL supports. Users can also diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index f4b0df58cf63b..e4736411fb5fe 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -347,7 +347,7 @@ rather than using the old parameter class `Strategy`. These new training method separate classification and regression, and they replace specialized parameter types with simple `String` types. -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the [Decision Trees Guide](mllib-decision-tree.html#examples). ## From 0.9 to 1.0 diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 54d9cd21909df..028bfec465bab 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -103,7 +103,7 @@ Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.m In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in - the case of `CrossValidator`. It is therefore less expensive, + the case of `CrossValidator`. It is, therefore, less expensive, but will not produce as reliable results when the training dataset is not sufficiently large. Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index df2be92d860e4..dc6b095f5d59b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -42,7 +42,7 @@ The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute Within -Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the +Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact, the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 76a00f18b3b90..b2300028e151b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -31,7 +31,7 @@ following parameters: ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. @@ -60,7 +60,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
    -In the following example we load rating data. Each row consists of a user, a product and a rating. +In the following example, we load rating data. Each row consists of a user, a product and a rating. We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS$) method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 35cee3275e3b5..5066bb29387dc 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -350,7 +350,7 @@ which is a tuple of `(Int, Int, Matrix)`. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -In general the use of non-deterministic RDDs can lead to errors. +In general, the use of non-deterministic RDDs can lead to errors. ### RowMatrix diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index a72680d52a26c..4e6b4530942f1 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -91,7 +91,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an [Principal component analysis (PCA)](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical method to find a rotation such that the first coordinate has the largest variance -possible, and each succeeding coordinate in turn has the largest variance possible. The columns of +possible, and each succeeding coordinate, in turn, has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. `spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7f277543d2e9a..d9dbbab4840a3 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -13,7 +13,7 @@ of the model on some criteria, which depends on the application and its requirem suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, -regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +regression, clustering, etc. Each of these types have well-established metrics for performance evaluation and those metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 8b89296b14cdd..bb29f65c0322f 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -105,7 +105,7 @@ p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top \]` where $V$ is the vocabulary size. -The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to $O(\log(V))$ diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index ca84551506b2b..99cab98c690c6 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -9,7 +9,7 @@ displayTitle: Regression - RDD-based API belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -28,7 +28,7 @@ best fitting the original data points. which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent -label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 034e89e25000e..73f6e206ca543 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -425,7 +425,7 @@ We create our model by initializing the weights to zero and register the streams testing then start the job. Printing predictions alongside true labels lets us easily see the result. -Finally we can save text files with data to the training or testing folders. +Finally, we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` the model will update. Anytime a text file is placed in `args(1)` you will see predictions. diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 14d76a6e41e23..04758903da89c 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -121,7 +121,7 @@ computation of the sum of the partial results from each worker machine is perfor standard spark routines. If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in -each iteration is exact (sub)gradient descent. In this case there is no randomness and no +each iteration is exact (sub)gradient descent. In this case, there is no randomness and no variance in the used step directions. On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point is sampled, i.e. `$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to @@ -135,7 +135,7 @@ algorithm in the family of quasi-Newton methods to solve the optimization proble quadratic without evaluating the second partial derivatives of the objective function to construct the Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so there is no vertical scalability issue (the number of training features) when computing the Hessian matrix -explicitly in Newton's method. As a result, L-BFGS often achieves rapider convergence compared with +explicitly in Newton's method. As a result, L-BFGS often achieves more rapid convergence compared with other first-order optimization. ### Choosing an Optimization Method diff --git a/docs/monitoring.md b/docs/monitoring.md index 01736c77b0979..6eaf33135744d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -214,7 +214,7 @@ incomplete attempt or the final successful attempt. 2. Incomplete applications are only updated intermittently. The time between updates is defined by the interval between checks for changed files (`spark.history.fs.update.interval`). -On larger clusters the update interval may be set to large values. +On larger clusters, the update interval may be set to large values. The way to view a running application is actually to view its own web UI. 3. Applications which exited without registering themselves as completed will be listed @@ -422,7 +422,7 @@ configuration property. If, say, users wanted to set the metrics namespace to the name of the application, they can set the `spark.metrics.namespace` property to a value like `${spark.app.name}`. This value is then expanded appropriately by Spark and is used as the root namespace of the metrics system. -Non driver and executor metrics are never prefixed with `spark.app.id`, nor does the +Non-driver and executor metrics are never prefixed with `spark.app.id`, nor does the `spark.metrics.namespace` property have any such affect on such metrics. Spark's metrics are decoupled into different diff --git a/docs/quick-start.md b/docs/quick-start.md index 07c520cbee6be..f1a2096cd4dbd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,11 +11,11 @@ This tutorial provides a quick introduction to using Spark. We will first introd interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -To follow along with this guide, first download a packaged release of Spark from the +To follow along with this guide, first, download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. -Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. # Interactive Analysis with the Spark Shell @@ -47,7 +47,7 @@ scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. +Now let's transform this Dataset into a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 2e29aef7f21a2..b6424090d2fea 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -818,7 +818,7 @@ The behavior of the above code is undefined, and may not work as intended. To ex The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. +In local mode, in some circumstances, the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 9c4644947c911..e9e1f3e280609 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -17,7 +17,7 @@ container images and entrypoints.** * A runnable distribution of Spark 2.3 or above. * A running Kubernetes cluster at version >= 1.6 with access configured to it using [kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster, -you may setup a test cluster on your local machine using +you may set up a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. * Be aware that the default minikube configuration is not enough for running Spark applications. @@ -221,7 +221,7 @@ that allows driver pods to create pods and services under the default Kubernetes [RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to be used by the driver pod through the configuration property -`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example to make the driver pod +`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example, to make the driver pod use the `spark` service account, a user simply adds the following option to the `spark-submit` command: ``` diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8e58892e2689f..3c2a1501ca692 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -90,7 +90,7 @@ Depending on your deployment environment you may wish to create a single set of Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. -Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. +Additionally, if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. ### Credential Specification Preference Order @@ -225,7 +225,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the Spark application will consume all resources offered to it by Mesos, -so we of course urge you to set this variable in any sort of +so we, of course, urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -233,14 +233,14 @@ The scheduler will start executors round-robin on the offers Mesos gives it, but there are no spread guarantees, as Mesos does not provide such guarantees on the offer stream. -In this mode spark executors will honor port allocation if such is -provided from the user. Specifically if the user defines +In this mode Spark executors will honor port allocation if such is +provided from the user. Specifically, if the user defines `spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the user, ephemeral ports are used as usual. This port honouring implementation -implies one task per host if the user defines a port. In the future network +implies one task per host if the user defines a port. In the future network, isolation shall be supported. The benefit of coarse-grained mode is much lower startup overhead, but @@ -486,7 +486,7 @@ See the [configuration page](configuration.html) for information on Spark config
    - + @@ -1797,6 +1798,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + + + + From 2a24c481da3f30b510deb62e5cf21c9463cf250c Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 24 Apr 2018 09:25:41 -0700 Subject: [PATCH 371/593] [SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features ## What changes were proposed in this pull request? - Multiple possible input types is added in validateAndTransformSchema() and computeCost() while checking column type - Add if statement in transform() to support array type as featuresCol - Add the case statement in fit() while selecting columns from dataset These changes will be applied to KMeans first, then to other clustering method ## How was this patch tested? unit test is added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21081 from ludatabricks/SPARK-23975. --- .../apache/spark/ml/clustering/KMeans.scala | 32 +++++++--- .../apache/spark/ml/util/DatasetUtils.scala | 63 +++++++++++++++++++ .../spark/ml/clustering/KMeansSuite.scala | 38 +++++++++++ 3 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1ad157a695a7d..d475c726e6f08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,13 +86,24 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) + /** + * Validates the input schema. + * @param schema input schema + */ + private[clustering] def validateSchema(schema: StructType): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + + SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) + } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + validateSchema(schema) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -125,8 +136,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -146,8 +160,10 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + validateSchema(dataset.schema) + + val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } parentModel.computeCost(data) @@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select( + DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 0000000000000..52619cb65489a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vectors, VectorUDT} +import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. + * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 77c9d482d95b6..5445ebe5c95eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMean with Array input") { + val featuresColNameD = "array_double_features" + val featuresColNameF = "array_float_features" + + val doubleUDF = udf { (features: Vector) => + val featureArray = Array.fill[Double](features.size)(0.0) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + val floatUDF = udf { (features: Vector) => + val featureArray = Array.fill[Float](features.size)(0.0f) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + + val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) + .drop("features") + val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) + .drop("features") + assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) + assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) + + val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) + val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) + val modelD = kmeansD.fit(newdatasetD) + val modelF = kmeansF.fit(newdatasetF) + val transformedD = modelD.transform(newdatasetD) + val transformedF = modelF.transform(newdatasetF) + + val predictDifference = transformedD.select("prediction") + .except(transformedF.select("prediction")) + assert(predictDifference.count() == 0) + assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) From ce7ba2e98e0a3b038e881c271b5905058c43155b Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Tue, 24 Apr 2018 09:57:09 -0700 Subject: [PATCH 372/593] [SPARK-23807][BUILD] Add Hadoop 3.1 profile with relevant POM fix ups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. Adds a `hadoop-3.1` profile build depending on the hadoop-3.1 artifacts. 1. In the hadoop-cloud module, adds an explicit hadoop-3.1 profile which switches from explicitly pulling in cloud connectors (hadoop-openstack, hadoop-aws, hadoop-azure) to depending on the hadoop-cloudstorage POM artifact, which pulls these in, has pre-excluded things like hadoop-common, and stays up to date with new connectors (hadoop-azuredatalake, hadoop-allyun). Goal: it becomes the Hadoop projects homework of keeping this clean, and the spark project doesn't need to handle new hadoop releases adding more dependencies. 1. the hadoop-cloud/hadoop-3.1 profile also declares support for jetty-ajax and jetty-util to ensure that these jars get into the distribution jar directory when needed by unshaded libraries. 1. Increases the curator and zookeeper versions to match those in hadoop-3, fixing spark core to build in sbt with the hadoop-3 dependencies. ## How was this patch tested? * Everything this has been built and tested against both ASF Hadoop branch-3.1 and hadoop trunk. * spark-shell was used to create connectors to all the stores and verify that file IO could take place. The spark hive-1.2.1 JAR has problems here, as it's version check logic fails for Hadoop versions > 2. This can be avoided with either of * The hadoop JARs built to declare their version as Hadoop 2.11 `mvn install -DskipTests -DskipShade -Ddeclared.hadoop.version=2.11` . This is safe for local test runs, not for deployment (HDFS is very strict about cross-version deployment). * A modified version of spark hive whose version check switch statement is happy with hadoop 3. I've done both, with maven and SBT. Three issues surfaced 1. A spark-core test failure —fixed in SPARK-23787. 1. SBT only: Zookeeper not being found in spark-core. Somehow curator 2.12.0 triggers some slightly different dependency resolution logic from previous versions, and Ivy was missing zookeeper.jar entirely. This patch adds the explicit declaration for all spark profiles, setting the ZK version = 3.4.9 for hadoop-3.1 1. Marking jetty-utils as provided in spark was stopping hadoop-azure from being able to instantiate the azure wasb:// client; it was using jetty-util-ajax, which could then not find a class in jetty-util. Author: Steve Loughran Closes #20923 from steveloughran/cloud/SPARK-23807-hadoop-31. --- assembly/pom.xml | 8 ++ core/pom.xml | 6 + dev/deps/spark-deps-hadoop-3.1 | 221 +++++++++++++++++++++++++++++++++ dev/test-dependencies.sh | 1 + hadoop-cloud/pom.xml | 83 ++++++++++++- pom.xml | 9 ++ 6 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 dev/deps/spark-deps-hadoop-3.1 diff --git a/assembly/pom.xml b/assembly/pom.xml index a207dae5a74ff..9608c96fd5369 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/core/pom.xml b/core/pom.xml index 9258a856028a0..093a9869b6dd7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,12 @@ org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new file mode 100644 index 0000000000000..97ad65a4096cb --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,221 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.8.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar +automaton-1.11-8.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.58.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.8.jar +commons-compress-1.4.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.4.jar +httpcore-4.4.8.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.8.jar +java-xmlbuilder-1.1.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jets3t-0.9.4.jar +jetty-webapp-9.3.20.v20170531.jar +jetty-xml-9.3.20.v20170531.jar +jline-2.12.1.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar +oro-2.0.8.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.8.2.jar +parquet-common-1.8.2.jar +parquet-encoding-1.8.2.jar +parquet-format-2.3.1.jar +parquet-hadoop-1.8.2.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.8.2.jar +protobuf-java-2.5.0.jar +py4j-0.10.6.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.5.9.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm5-shaded-4.4.jar +xz-1.0.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3bf7618e1ea96..2fbd6b5e98f7f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -34,6 +34,7 @@ MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 8e424b1c50236..2c39a7df0146e 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/pom.xml b/pom.xml index 0a711f287a53f..88e77ff874748 100644 --- a/pom.xml +++ b/pom.xml @@ -2671,6 +2671,15 @@ + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn From 83013752e3cfcbc3edeef249439ac20b143eeabc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Apr 2018 10:40:25 -0700 Subject: [PATCH 373/593] [SPARK-23455][ML] Default Params in ML should be saved separately in metadata ## What changes were proposed in this pull request? We save ML's user-supplied params and default params as one entity in metadata. During loading the saved models, we set all the loaded params into created ML model instances as user-supplied params. It causes some problems, e.g., if we strictly disallow some params to be set at the same time, a default param can fail the param check because it is treated as user-supplied param after loading. The loaded default params should not be set as user-supplied params. We should save ML default params separately in metadata. For backward compatibility, when loading metadata, if it is a metadata file from previous Spark, we shouldn't raise error if we can't find the default param field. ## How was this patch tested? Pass existing tests and added tests. Author: Liang-Chi Hsieh Closes #20633 from viirya/save-ml-default-params. --- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 4 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 4 +- .../RandomForestClassifier.scala | 4 +- .../spark/ml/clustering/BisectingKMeans.scala | 2 +- .../spark/ml/clustering/GaussianMixture.scala | 2 +- .../apache/spark/ml/clustering/KMeans.scala | 2 +- .../org/apache/spark/ml/clustering/LDA.scala | 4 +- .../feature/BucketedRandomProjectionLSH.scala | 2 +- .../apache/spark/ml/feature/Bucketizer.scala | 24 ---- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../org/apache/spark/ml/feature/Imputer.scala | 2 +- .../spark/ml/feature/MaxAbsScaler.scala | 2 +- .../apache/spark/ml/feature/MinHashLSH.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../ml/feature/OneHotEncoderEstimator.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../ml/feature/QuantileDiscretizer.scala | 24 ---- .../apache/spark/ml/feature/RFormula.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 13 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 4 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 4 +- .../spark/ml/tuning/CrossValidator.scala | 6 +- .../ml/tuning/TrainValidationSplit.scala | 6 +- .../org/apache/spark/ml/util/ReadWrite.scala | 130 ++++++++++++------ .../spark/ml/util/DefaultReadWriteTest.scala | 73 +++++++++- project/MimaExcludes.scala | 6 + 44 files changed, 223 insertions(+), 147 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 771cd4fe91dcf..57797d1cc4978 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -279,7 +279,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) val model = new DecisionTreeClassificationModel(metadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c0255103bc313..0aa24f0a3cfcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -379,14 +379,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8f950cd28c3aa..80c537e1e0eb2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ee4b01058c75c..e426263910f26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1270,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index af2e4699924e5..57ba47e596a97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d47435..45fb585ed2262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 5348d882cfd67..7df53a6b8ad10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -289,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -484,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bb972e9706fc1..f1ef26a07d3f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -319,14 +319,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica case (treeMetadata, root) => val tree = new DecisionTreeClassificationModel(treeMetadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f7c422dc0faea..addc12ac52ec1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -193,7 +193,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a6938..b5804900c0358 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -233,7 +233,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d475c726e6f08..de61c9c089a36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -280,7 +280,7 @@ object KMeansModel extends MLReadable[KMeansModel] { sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159f..47077230fac0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +391,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 41eaaf9679914..a906e954fecd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -238,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f49c410cbcfe2..f99649f7fa164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } - - private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea3..dbfb199ccd58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 9e0ed437e7bfc..10c48c3f52085 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -363,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb8..58897cca4e5c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db8..1c074e204ad99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f67..90eceb0d61b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 556848e45532d..a67a3b0abbc1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -205,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54cd..2e0ae4af66f06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index bd1e3426c8780..4a44f3186538d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d7930..8172491a517d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 3b4c25478fb1d..56e2c543d100a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - private[QuantileDiscretizer] - class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e214765e3307f..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -602,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d2..91b0707dec3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 67cdb097217a2..a833d8b270cf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff0..0e7396a621dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d6..fc9996d69ba72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 3d041fc80eb7f..0bf405d9abf9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -335,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9a83a5882ce29..e6c347ed17c15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { @@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0e..a23f9552b9e5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4b46c3831d75f..7c6ec2a8419fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 5cef5c9f21f1e..8bcf0793a64c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) val model = new DecisionTreeRegressionModel(metadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 834aaa0e362d1..8598e808c4946 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -311,7 +311,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +319,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 4c3f1431d5077..e030a40cb19be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1146,7 +1146,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474b..b046897ab2b7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 9cdd3a051e719..f1d9a4453deaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -799,7 +799,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f77398ba2a22..4509f85aafd12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -276,14 +276,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c2826dcc08634..5e916cc4a9fdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -424,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8d1b9a8ddab59..13369c4df7180 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -407,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7edcd498678cc..72a60e04360d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. @@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. - * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434ee..4d9e664850c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. + val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. + assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a87fa68422c34..7d0e88ee20c3f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,6 +62,12 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), From 379bffa0525a4343f8c10e51ed192031922f9874 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 24 Apr 2018 11:02:22 -0700 Subject: [PATCH 374/593] [SPARK-23990][ML] Instruments logging improvements - ML regression package ## What changes were proposed in this pull request? Instruments logging improvements - ML regression package I add an `OptionalInstrument` class which used in `WeightLeastSquares` and `IterativelyReweightedLeastSquares`. ## How was this patch tested? N/A Author: WeichenXu Closes #21078 from WeichenXu123/inst_reg. --- .../classification/LogisticRegression.scala | 4 +- .../IterativelyReweightedLeastSquares.scala | 18 +++-- .../spark/ml/optim/WeightedLeastSquares.scala | 32 +++++---- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../GeneralizedLinearRegression.scala | 14 ++-- .../ml/regression/LinearRegression.scala | 22 +++--- .../spark/ml/tree/impl/RandomForest.scala | 2 + .../spark/ml/util/Instrumentation.scala | 68 ++++++++++++++++++- 8 files changed, 125 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e426263910f26..06ca37bc75146 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -816,7 +816,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4d..572b8cf0051b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd29..1b7c15f1f0a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. */ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 7c6ec2a8419fd..e27a96e1f5dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is different from R survival::survreg.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index e030a40cb19be..143c8a3548b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -492,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -501,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f1d9a4453deaa..c45ade94a4e33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -339,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -378,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -385,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -415,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -430,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -522,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 056a94b351f79..905870178e549 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -108,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index e694bc27b2f1e..3247c394dfa64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.util import java.util.UUID +import scala.reflect.ClassTag + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -40,7 +42,8 @@ import org.apache.spark.sql.Dataset * @tparam E the type of the estimator */ private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { + val estimator: E, + val dataset: RDD[_]) extends Logging { private val id = UUID.randomUUID() private val prefix = { @@ -103,6 +106,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( logNamedValue(Instrumentation.loggerTags.numClasses, num) } + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) + } + /** * Logs the value with customized name field. */ @@ -114,6 +121,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Double): Unit = { + log(compact(render(name -> value))) + } + /** * Logs the successful completion of the training session. */ @@ -131,6 +142,8 @@ private[spark] object Instrumentation { val numFeatures = "numFeatures" val numClasses = "numClasses" val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" } /** @@ -150,3 +163,56 @@ private[spark] object Instrumentation { } } + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. + */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } + } + + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + + /** + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. + */ + def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), + instr.estimator.getClass.getName.stripSuffix("$")) + } + + /** + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. + */ + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) + } +} From 7b1e6523af3c96043aa8d2763e5f18b6e2781c3d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Apr 2018 14:33:33 -0700 Subject: [PATCH 375/593] [SPARK-24056][SS] Make consumer creation lazy in Kafka source for Structured streaming ## What changes were proposed in this pull request? Currently, the driver side of the Kafka source (i.e. KafkaMicroBatchReader) eagerly creates a consumer as soon as the Kafk aMicroBatchReader is created. However, we create dummy KafkaMicroBatchReader to get the schema and immediately stop it. Its better to make the consumer creation lazy, it will be created on the first attempt to fetch offsets using the KafkaOffsetReader. ## How was this patch tested? Existing unit tests Author: Tathagata Das Closes #21134 from tdas/SPARK-24056. --- .../sql/kafka010/KafkaOffsetReader.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..82066697cb95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. - */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } From d6c26d1c9a8f747a3e0d281a27ea9eb4d92102e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 24 Apr 2018 17:06:03 -0700 Subject: [PATCH 376/593] [SPARK-24038][SS] Refactor continuous writing to its own class ## What changes were proposed in this pull request? Refactor continuous writing to its own class. See WIP https://github.com/jose-torres/spark/pull/13 for the overall direction this is going, but I think this PR is very isolated and necessary anyway. ## How was this patch tested? existing unit tests - refactoring only Author: Jose Torres Closes #21116 from jose-torres/SPARK-24038. --- .../datasources/v2/DataSourceV2Strategy.scala | 4 + .../datasources/v2/WriteToDataSourceV2.scala | 74 +---------- .../continuous/ContinuousExecution.scala | 2 +- .../WriteToContinuousDataSource.scala | 31 +++++ .../WriteToContinuousDataSourceExec.scala | 124 ++++++++++++++++++ 5 files changed, 165 insertions(+), 70 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1ac9572de6412..c2a31442d2be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -32,6 +33,9 @@ object DataSourceV2Strategy extends Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index e80b44c1cdc66..ea283ed77efda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -65,25 +65,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e s"The input RDD has ${messages.length} partitions.") try { - val runTask = writer match { - // This case means that we're doing continuous processing. In microbatch streaming, the - // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. - case w: StreamWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) - } - sparkContext.runJob( rdd, - runTask, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message @@ -91,14 +76,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } ) - if (!writer.isInstanceOf[StreamWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { - case _: InterruptedException if writer.isInstanceOf[StreamWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -111,8 +92,6 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } logError(s"Data source writer $writer aborted.") cause match { - // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) case _ => throw cause @@ -168,49 +147,6 @@ object DataWritingSparkTask extends Logging { logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } } class InternalRowDataWriterFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 951d694355ec5..f58146ac42398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -199,7 +199,7 @@ class ContinuousExecution( triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) val reader = withSink.collect { case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 0000000000000..943c731a70529 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 0000000000000..ba88ae1af469a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.util.Utils + +/** + * The physical plan for writing data into a continuous processing [[StreamWriter]]. + */ +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${rdd.getNumPartitions} partitions.") + // Let the epoch coordinator know how many partitions the write RDD has. + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + WriteToContinuousDataSourceExec.run(writerFactory, context, iter), + rdd.partitions.indices) + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} + +object WriteToContinuousDataSourceExec extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): Unit = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + do { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } while (!context.isInterrupted()) + } +} From 5fea17b3befc50aef59b799711d03b9552f21b19 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 25 Apr 2018 11:19:08 +0900 Subject: [PATCH 377/593] [SPARK-23821][SQL] Collection function: flatten ## What changes were proposed in this pull request? This PR adds a new collection function that transforms an array of arrays into a single array. The PR comprises: - An expression for flattening array structure - Flatten function - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(Seq(1, 2), Seq(4, 5)), Seq(null, Seq(1)) ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 064 */ project_numElements, /* 065 */ 4); /* 066 */ if (project_size > 2147483632) { /* 067 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 068 */ project_size + " bytes of data due to exceeding the limit 2147483632" + /* 069 */ " bytes for UnsafeArrayData."); /* 070 */ } /* 071 */ /* 072 */ byte[] project_array = new byte[(int)project_size]; /* 073 */ UnsafeArrayData project_tempArrayData = new UnsafeArrayData(); /* 074 */ Platform.putLong(project_array, 16, project_numElements); /* 075 */ project_tempArrayData.pointTo(project_array, 16, (int)project_size); /* 076 */ int project_counter = 0; /* 077 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 078 */ ArrayData arr = inputadapter_value.getArray(k); /* 079 */ for (int l = 0; l < arr.numElements(); l++) { /* 080 */ if (arr.isNullAt(l)) { /* 081 */ project_tempArrayData.setNullAt(project_counter); /* 082 */ } else { /* 083 */ project_tempArrayData.setInt( /* 084 */ project_counter, /* 085 */ arr.getInt(l) /* 086 */ ); /* 087 */ } /* 088 */ project_counter++; /* 089 */ } /* 090 */ } /* 091 */ project_value = project_tempArrayData; /* 092 */ /* 093 */ } /* 094 */ /* 095 */ } ``` ### Non-primitive type ``` val df = Seq( Seq(Seq("a", "b"), Seq(null, "d")), Seq(null, Seq("a")) ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ Object[] project_arrayObject = new Object[(int)project_numElements]; /* 064 */ int project_counter = 0; /* 065 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 066 */ ArrayData arr = inputadapter_value.getArray(k); /* 067 */ for (int l = 0; l < arr.numElements(); l++) { /* 068 */ project_arrayObject[project_counter] = arr.getUTF8String(l); /* 069 */ project_counter++; /* 070 */ } /* 071 */ } /* 072 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject); /* 073 */ /* 074 */ } /* 075 */ /* 076 */ } ``` Author: mn-mikke Closes #20938 from mn-mikke/feature/array-api-flatten-to-master. --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 176 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 95 ++++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 79 ++++++++ 6 files changed, 376 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..de53b48b6f3b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2191,6 +2191,23 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..6afcf309bd690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -413,6 +413,7 @@ object FunctionRegistry { expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[Flatten]("flatten"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..bc71b5f34ce4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression { override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if (!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + + | " bytes for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[(int)$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..b49fa76b2a781 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..d2f057310f89b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3340,6 +3340,14 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..03605c30036a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -691,6 +691,85 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 64e8408e6fa2d74929601b01a29771738f6d8c65 Mon Sep 17 00:00:00 2001 From: liutang123 Date: Wed, 25 Apr 2018 18:10:51 +0800 Subject: [PATCH 378/593] [SPARK-24012][SQL] Union of map and other compatible column ## What changes were proposed in this pull request? Union of map and other compatible column result in unresolved operator 'Union; exception Reproduction `spark-sql>select map(1,2), 'str' union all select map(1,2,3,null), 1` Output: ``` Error in query: unresolved operator 'Union;; 'Union :- Project [map(1, 2) AS map(1, 2)#106, str AS str#107] : +- OneRowRelation$ +- Project [map(1, cast(2 as int), 3, cast(null as int)) AS map(1, CAST(2 AS INT), 3, CAST(NULL AS INT))#109, 1 AS 1#108] +- OneRowRelation$ ``` So, we should cast part of columns to be compatible when appropriate. ## How was this patch tested? Added a test (query union of map and other columns) to SQLQueryTestSuite's union.sql. Author: liutang123 Closes #21100 from liutang123/SPARK-24012. --- .../sql/catalyst/analysis/TypeCoercion.scala | 8 ++++ .../test/resources/sql-tests/inputs/union.sql | 11 +++++ .../resources/sql-tests/results/union.sql.out | 42 ++++++++++++++----- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cfcbd8db559a3..25bad28a2a209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -112,6 +112,14 @@ object TypeCoercion { StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) })) + case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + + case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) + Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) + case _ => None } diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad033..6da1b9b49b226 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0cf..b023df825d814 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + From 20ca208bcda6f22fe7d9fb54144de435b4237536 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 25 Apr 2018 19:06:18 +0800 Subject: [PATCH 379/593] [SPARK-23880][SQL] Do not trigger any jobs for caching data ## What changes were proposed in this pull request? This pr fixed code so that `cache` could prevent any jobs from being triggered. For example, in the current master, an operation below triggers a actual job; ``` val df = spark.range(10000000000L) .filter('id > 1000) .orderBy('id.desc) .cache() ``` This triggers a job while the cache should be lazy. The problem is that, when creating `InMemoryRelation`, we build the RDD, which calls `SparkPlan.execute` and may trigger jobs, like sampling job for range partitioner, or broadcast job. This pr removed the code to build a cached `RDD` in the constructor of `InMemoryRelation` and added `CachedRDDBuilder` to lazily build the `RDD` in `InMemoryRelation`. Then, the first call of `CachedRDDBuilder.cachedColumnBuffers` triggers a job to materialize the cache in `InMemoryTableScanExec` . ## How was this patch tested? Added tests in `CachedTableSuite`. Author: Takeshi Yamamuro Closes #21018 from maropu/SPARK-23880. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 14 +- .../execution/columnar/InMemoryRelation.scala | 155 ++++++++++-------- .../columnar/InMemoryTableScanExec.scala | 10 +- .../apache/spark/sql/CachedTableSuite.scala | 36 +++- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 8 files changed, 133 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 917168162b236..cd4def71e6f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2933,7 +2933,7 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a8794be7280c7..93bf91e56f1bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -119,7 +119,7 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } @@ -138,16 +138,14 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + cd.cachedRepresentation.cacheBuilder.clearCache() // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, + cacheBuilder = cd.cachedRepresentation + .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a7ba9b86a176f..da35a4734e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - logicalPlan: LogicalPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) -} - - /** * CachedBatch is a cached batch of rows. * @@ -55,58 +42,41 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics, - override val outputOrdering: Seq[SortOrder]) - extends logical.LeafNode with MultiInstanceRelation { - - override protected def innerChildren: Seq[SparkPlan] = Seq(child) - - override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), - storageLevel = StorageLevel.NONE, - child = child.canonicalized, - tableName = None)( - _cachedColumnBuffers, - sizeInBytesStats, - statsOfPlanToCache, - outputOrdering) + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override def producedAttributes: AttributeSet = outputSet - - @transient val partitionStatistics = new PartitionStatistics(output) + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - override def computeStats(): Statistics = { - if (sizeInBytesStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) - } else { - Statistics(sizeInBytes = sizeInBytesStats.value.longValue) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } } + _cachedColumnBuffers } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } + } } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -154,32 +124,77 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - sizeInBytesStats, + cacheBuilder)( statsOfPlanToCache, outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e73e1378d52e3..ea315fb71617c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -154,7 +154,7 @@ case class InMemoryTableScanExec( private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -163,16 +163,16 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { + relation.cachedPlan.outputPartitioning match { case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + case _ => relation.cachedPlan.outputPartitioning } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. private val stats = relation.partitionStatistics @@ -252,7 +252,7 @@ case class InMemoryTableScanExec( // within the map Partitions closure. val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 669e5f2bf4e65..81b7e18773f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} @@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } @@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 40915a102bab0..f0dfe6b76f7ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9b7b316211d30..863703b15f4f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a6178..569f00c053e5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head From 396938ef02c70468e1695872f96b1e9aff28b7ea Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 12:21:55 -0700 Subject: [PATCH 380/593] [SPARK-24050][SS] Calculate input / processing rates correctly for DataSourceV2 streaming sources ## What changes were proposed in this pull request? In some streaming queries, the input and processing rates are not calculated at all (shows up as zero) because MicroBatchExecution fails to associated metrics from the executed plan of a trigger with the sources in the logical plan of the trigger. The way this executed-plan-leaf-to-logical-source attribution works is as follows. With V1 sources, there was no way to identify which execution plan leaves were generated by a streaming source. So did a best-effort attempt to match logical and execution plan leaves when the number of leaves were same. In cases where the number of leaves is different, we just give up and report zero rates. An example where this may happen is as follows. ``` val cachedStaticDF = someStaticDF.union(anotherStaticDF).cache() val streamingInputDF = ... val query = streamingInputDF.join(cachedStaticDF).writeStream.... ``` In this case, the `cachedStaticDF` has multiple logical leaves, but in the trigger's execution plan it only has leaf because a cached subplan is represented as a single InMemoryTableScanExec leaf. This leads to a mismatch in the number of leaves causing the input rates to be computed as zero. With DataSourceV2, all inputs are represented in the executed plan using `DataSourceV2ScanExec`, each of which has a reference to the associated logical `DataSource` and `DataSourceReader`. So its easy to associate the metrics to the original streaming sources. In this PR, the solution is as follows. If all the streaming sources in a streaming query as v2 sources, then use a new code path where the execution-metrics-to-source mapping is done directly. Otherwise we fall back to existing mapping logic. ## How was this patch tested? - New unit tests using V2 memory source - Existing unit tests using V1 source Author: Tathagata Das Closes #21126 from tdas/SPARK-24050. --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 9 +- .../streaming/ProgressReporter.scala | 146 +++++++++++++----- .../sql/streaming/StreamingQuerySuite.scala | 134 +++++++++++++++- 3 files changed, 245 insertions(+), 44 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d21..d2d04b68de6ab 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -563,7 +563,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("ensure stream-stream self-join generates only one offset in offset log") { + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { val topic = newTopic() testUtils.createTopic(topic, partitions = 2) require(testUtils.getLatestOffsets(Set(topic)).size === 2) @@ -587,7 +587,12 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AddKafkaData(Set(topic), 1, 2), CheckAnswer((1, 1, 1), (2, 2, 2)), AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1e5be9c12762..16ad3ef9a3d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -141,7 +143,7 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + val sourceProgress = sources.distinct.map { source => val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -207,62 +209,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. - // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => - logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + } + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will leads to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. + // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897c..390d67d1feb27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -466,7 +466,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("input row calculation with mixed batch and streaming sources") { + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -479,7 +489,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -492,6 +502,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -733,6 +858,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * From ac4ca7c4dd3ff666ec70aeb26ac84cffa557ee12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 Apr 2018 13:42:44 -0700 Subject: [PATCH 381/593] [SPARK-24012][SQL][TEST][FOLLOWUP] add unit test ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/21100 ## How was this patch tested? N/A Author: Wenchen Fan Closes #21154 from cloud-fan/test. --- .../catalyst/analysis/TypeCoercionSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index fd6a3121663ed..1cc431aaf0a60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -429,6 +429,24 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { From 95a651339ec39d5753e849e578ad715be0d7c83e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 09:12:38 +0800 Subject: [PATCH 382/593] [SPARK-24069][R] Add array_min / array_max functions ## What changes were proposed in this pull request? This PR proposes to add array_max and array_min in R side too. array_max: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_max(mutated$v1))) ``` ``` array_max(v1) 1 4 2 4 3 4 4 3 5 3 6 3 ``` array_min: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, array_min(mutated$v1))) ``` ``` array_min(v1) 1 6 2 6 3 4 4 6 5 8 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21142 from HyukjinKwon/sparkr_array_min_array_max. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 27 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 9 ++++++++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 55dec177ea853..f36d462a83cb0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,8 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_max", + "array_min", "array_position", "asc", "ascii", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7b3aa05074563..ec4bd4e73c7e5 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -206,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -2992,6 +2993,32 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + #' @details #' \code{array_position}: Locates the position of the first occurrence of the given value #' in the given array. Returns NA if either of the arguments are NA. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f30ac9e4295e4..562d3399ee9c8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,14 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a384997830276..8cc2db7a140f9 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,11 +1479,18 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_position(), element_at() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() + # and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 0)) From 3f1e999d3d215bb3b867bcd83ec5c799448ec730 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 26 Apr 2018 09:14:24 +0800 Subject: [PATCH 383/593] [SPARK-23849][SQL] Tests for samplingRatio of json datasource ## What changes were proposed in this pull request? Added the `samplingRatio` option to the `json()` method of PySpark DataFrame Reader. Improving existing tests for Scala API according to review of the PR: https://github.com/apache/spark/pull/20959 ## How was this patch tested? Added new test for PySpark, updated 2 existing tests according to reviews of https://github.com/apache/spark/pull/20959 and added new negative test Author: Maxim Gekk Closes #21056 from MaxGekk/json-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 8 +++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonSuite.scala | 63 ++++++++++--------- .../datasources/json/TestJsonData.scala | 12 ++++ 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bd79bc2f43e5..df176c579fc8b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -239,6 +239,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -256,7 +258,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e99c8e3c6b10..98fa1b54b0a17 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3018,6 +3018,14 @@ def test_sort_with_nulls_order(self): df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d640fdc530ce2..b44552f0eb17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -374,6 +374,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used + * for schema inferring.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 70aee561ff0f6..a58dff827b92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2128,38 +2128,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } - test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - withTempPath { path => - val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), - StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) - for (i <- 0 until 100) { - if (predefinedSample.contains(i)) { - writer.write(s"""{"f1":${i.toString}}""" + "\n") - } else { - writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") - } - } - writer.close() + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } - val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) - assert(ds.schema == new StructType().add("f1", LongType)) - } + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) } - test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { - val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - if (predefinedSample.contains(i)) { - s"""{"f1":${i.toString}}""" + "\n" - } else { - s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" - } - }.toDS() - val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) - assert(ds.schema == new StructType().add("f1", LongType)) + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f04..6e9559edf8ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } From 58c55cb4a6d72d72df908e37aa63f617b3cc5587 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 12:19:20 +0900 Subject: [PATCH 384/593] [SPARK-23902][SQL] Add roundOff flag to months_between ## What changes were proposed in this pull request? HIVE-15511 introduced the `roundOff` flag in order to disable the rounding to 8 digits which is performed in `months_between`. Since this can be a computational intensive operation, skipping it may improve performances when the rounding is not needed. ## How was this patch tested? modified existing UT Author: Marco Gaido Closes #21008 from mgaido91/SPARK-23902. --- python/pyspark/sql/functions.py | 10 +++- .../expressions/datetimeExpressions.scala | 33 +++++++---- .../sql/catalyst/util/DateTimeUtils.scala | 32 ++++------ .../expressions/DateExpressionsSuite.scala | 59 +++++++++++-------- .../catalyst/util/DateTimeUtilsSuite.scala | 30 +++++++--- .../org/apache/spark/sql/functions.scala | 13 +++- .../apache/spark/sql/DateFunctionsSuite.scala | 7 +++ 7 files changed, 118 insertions(+), 66 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index de53b48b6f3b4..38ae41a5dafe6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1088,16 +1088,20 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ Returns the number of months between date1 and date2. + Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b9b2cd5bdb9f0..d882d06cfd625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1156,38 +1156,49 @@ case class AddMonths(startDate: Expression, numMonths: Expression) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. + The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c85..4b00a61c6cf91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -870,24 +870,14 @@ object DateTimeUtils { * If time1 and time2 having the same day of month, or both are the last day of month, * it returns an integer (time under a day will be ignored). * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * If `roundOff` is set to true, the result is rounded to 8 decimal places. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. - * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -906,8 +896,12 @@ object DateTimeUtils { val timeInDay2 = millis2 - daysToMillis(date2, timeZone) val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff + } } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 080ec487cfa6a..63b24fb9eb13a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -464,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa3..cbf6106697f30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f057310f89b..f1587cd032adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2691,11 +2691,22 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. + * The result is rounded off to 8 digits. * @group datetime_funcs * @since 1.5.0 */ def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + new MonthsBetween(date1.expr, date2.expr) + } + + /** + * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf6..f712baa7a9134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -327,6 +327,13 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { From cd10f9df8284ee8a5d287b2cd204c70b8ba87f5e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 13:37:13 +0900 Subject: [PATCH 385/593] [SPARK-23916][SQL] Add array_join function ## What changes were proposed in this pull request? The PR adds the SQL function `array_join`. The behavior of the function is based on Presto's one. The function accepts an `array` of `string` which is to be joined, a `string` which is the delimiter to use between the items of the first argument and optionally a `string` which is used to replace `null` values. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21011 from mgaido91/SPARK-23916. --- python/pyspark/sql/functions.py | 21 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 169 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 35 ++++ .../org/apache/spark/sql/functions.scala | 19 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 23 +++ 6 files changed, 268 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 38ae41a5dafe6..ad4bd6f5089e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,27 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + @since(1.5) @ignore_unicode_prefix def concat(*cols): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6afcf309bd690..6bc7b4e4f7cb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -401,6 +401,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bc71b5f34ce4a..90223b9126555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -378,6 +378,175 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if (replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + s""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + s""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + +} + /** * Returns the minimum value in the array. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b49fa76b2a781..7048d93fd5649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f1587cd032adc..25afaacc38d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,25 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) + } + /** * Concatenates multiple input columns together into a single column. * The function works with strings, binary and compatible array columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03605c30036a3..c216d1322a06c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + } + test("array_min function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From ffaf0f9fd407aeba7006f3d785ea8a0e51187357 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 26 Apr 2018 13:27:33 +0800 Subject: [PATCH 386/593] [SPARK-24062][THRIFT SERVER] Fix SASL encryption cannot enabled issue in thrift server ## What changes were proposed in this pull request? For the details of the exception please see [SPARK-24062](https://issues.apache.org/jira/browse/SPARK-24062). The issue is: Spark on Yarn stores SASL secret in current UGI's credentials, this credentials will be distributed to AM and executors, so that executors and drive share the same secret to communicate. But STS/Hive library code will refresh the current UGI by UGI's loginFromKeytab() after Spark application is started, this will create a new UGI in the current driver's context with empty tokens and secret keys, so secret key is lost in the current context's UGI, that's why Spark driver throws secret key not found exception. In Spark 2.2 code, Spark also stores this secret key in SecurityManager's class variable, so even UGI is refreshed, the secret is still existed in the object, so STS with SASL can still be worked in Spark 2.2. But in Spark 2.3, we always search key from current UGI, which makes it fail to work in Spark 2.3. To fix this issue, there're two possible solutions: 1. Fix in STS/Hive library, when a new UGI is refreshed, copy the secret key from original UGI to the new one. The difficulty is that some codes to refresh the UGI is existed in Hive library, which makes us hard to change the code. 2. Roll back the logics in SecurityManager to match Spark 2.2, so that this issue can be fixed. 2nd solution seems a simple one. So I will propose a PR with 2nd solution. ## How was this patch tested? Verified in local cluster. CC vanzin tgravescs please help to review. Thanks! Author: jerryshao Closes #21138 from jerryshao/SPARK-24062. --- .../main/scala/org/apache/spark/SecurityManager.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 09ec8932353a0..dbfd5a514c189 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -89,6 +89,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -321,6 +322,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -364,8 +371,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - val secretStr = HashCodes.fromBytes(secretBytes).toString() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) + secretKey = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From d1eb8d3ddc877958512194cc8f5dd8119b41bed0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 23:24:05 -0700 Subject: [PATCH 387/593] [SPARK-24094][SS][MINOR] Change description strings of v2 streaming sources to reflect the change ## What changes were proposed in this pull request? This makes it easy to understand at runtime which version is running. Great for debugging production issues. ## How was this patch tested? Not necessary. Author: Tathagata Das Closes #21160 from tdas/SPARK-24094. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../streaming/sources/RateStreamMicroBatchReader.scala | 2 +- .../apache/spark/sql/execution/streaming/sources/socket.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 2ed49ba3f5495..cbe655f9bff1f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -169,7 +169,7 @@ private[kafka010] class KafkaMicroBatchReader( kafkaOffsetReader.close() } - override def toString(): String = s"Kafka[$kafkaOffsetReader]" + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" /** * Read initial partition offsets from the checkpoint, or decide the offsets and write them to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 6cf8520fc544f..f54291bea6678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -177,7 +177,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: override def stop(): Unit = {} - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b463398..90f4a5ba4234d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -214,7 +214,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def toString: String = s"TextSocket[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } class TextSocketSourceProvider extends DataSourceV2 From ce2f919f8df1b794ceaa23e1a59d5d541ed47bf5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 26 Apr 2018 19:07:13 +0800 Subject: [PATCH 388/593] [SPARK-23799][SQL][FOLLOW-UP] FilterEstimation.evaluateInSet produces wrong stats for STRING ## What changes were proposed in this pull request? `colStat.min` AND `colStat.max` are empty for string type. Thus, `evaluateInSet` should not return zero when either `colStat.min` or `colStat.max`. ## How was this patch tested? Added a test case. Author: gatorsmile Closes #21147 from gatorsmile/cached. --- .../logical/statsEstimation/FilterEstimation.scala | 12 ++++++++---- .../statsEstimation/FilterEstimationSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 263c9ba60d145..5a3eeefaedb18 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,13 +392,13 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv - if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { - return Some(0.0) - } - // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -422,6 +422,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 16cb5d032cf57..47bfa62569583 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -368,6 +368,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 0) } + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), From 4f1e38649ebc7710850b7c40e6fb355775e7bb7f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Apr 2018 14:21:22 -0700 Subject: [PATCH 389/593] [SPARK-24057][PYTHON] put the real data type in the AssertionError message ## What changes were proposed in this pull request? Print out the data type in the AssertionError message to make it more meaningful. ## How was this patch tested? I manually tested the changed code on my local, but didn't add any test. Author: Huaxin Gao Closes #21159 from huaxingao/spark-24057. --- python/pyspark/sql/types.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f6534836d64a..3cd7a2ef115af 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -289,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -343,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... == MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -402,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name From f7435bec6a9348cfbbe26b13c230c08545d16067 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 15:11:42 -0700 Subject: [PATCH 390/593] [SPARK-24044][PYTHON] Explicitly print out skipped tests from unittest module ## What changes were proposed in this pull request? This PR proposes to remove duplicated dependency checking logics and also print out skipped tests from unittests. For example, as below: ``` Skipped tests in pyspark.sql.tests with pypy: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ... Skipped tests in pyspark.sql.tests with python3: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ... ``` Currently, it's not printed out in the console. I think we should better print out skipped tests in the console. ## How was this patch tested? Manually tested. Also, fortunately, Jenkins has good environment to test the skipped output. Author: hyukjinkwon Closes #21107 from HyukjinKwon/skipped-tests-print. --- python/pyspark/ml/tests.py | 16 +++-- python/pyspark/mllib/tests.py | 4 +- python/pyspark/sql/tests.py | 51 +++++++------ python/pyspark/streaming/tests.py | 4 +- python/pyspark/tests.py | 12 +--- python/run-tests.py | 115 +++++++++++++----------------- 6 files changed, 98 insertions(+), 104 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2ec0be60e9fa9..093593132e56d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2136,17 +2136,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -2662,6 +2668,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f1088..14d788b0bef60 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -1767,9 +1767,9 @@ def test_pca(self): if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98fa1b54b0a17..6b28c557a803e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3096,23 +3096,28 @@ def setUpClass(cls): filename_pattern = ( "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" "TestQueryExecutionListener.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( "'org.apache.spark.sql.TestQueryExecutionListener' is not " "available. Will skip the related tests.") - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - @classmethod def tearDownClass(cls): - cls.spark.stop() + if hasattr(cls, "spark"): + cls.spark.stop() def tearDown(self): self.spark._jvm.OnSuccessCall.clear() @@ -3196,18 +3201,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False os.unlink(cls.tempdir.name) - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -5316,6 +5325,6 @@ def test_invalid_args(self): if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 103940923dd4d..d77f1baa1f344 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1590,11 +1590,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9111dbbed5929..8392d7f29af53 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2353,15 +2353,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff --git a/python/run-tests.py b/python/run-tests.py index 6b41b5ee22814..f408fc5082b3d 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -32,6 +32,7 @@ else: import queue as Queue from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -50,6 +51,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -109,8 +111,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -152,65 +180,17 @@ def parse_opts(): return opts -def _check_dependencies(python_exec, modules_to_test): - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - try: - subprocess_check_output( - [python_exec, "-c", "import coverage"], - stderr=open(os.devnull, 'w')) - except: - print_red("Coverage is not installed in Python executable '%s' " - "but 'COVERAGE_PROCESS_START' environment variable is set, " - "exiting." % python_exec) - sys.exit(-1) - - # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and - # explicitly prints out. See SPARK-23300. - if pyspark_sql in modules_to_test: - # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. - minimum_pyarrow_version = '0.8.0' - minimum_pandas_version = '0.19.2' - - try: - pyarrow_version = subprocess_check_output( - [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): - LOGGER.info("Will test PyArrow related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) - except: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) - - try: - pandas_version = subprocess_check_output( - [python_exec, "-c", "import pandas; print(pandas.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): - LOGGER.info("Will test Pandas related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) - except: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) +def _check_coverage(python_exec): + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) def main(): @@ -237,9 +217,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - # Check if the python executable has proper dependencies installed to run tests - # for given modules properly. - _check_dependencies(python_exec, modules_to_test) + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -281,6 +262,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() From 9ee9fcf5223efdf7543161b7bc99131111876b92 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 26 Apr 2018 15:38:11 -0700 Subject: [PATCH 391/593] [SPARK-24083][YARN] Log stacktrace for uncaught exception ## What changes were proposed in this pull request? Log stacktrace for uncaught exception ## How was this patch tested? UT and manually test Author: zhoukang Closes #21151 from caneGuy/zhoukang/log-stacktrace. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d04989e138f83..650840045361c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -308,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) } } From 8aa1d7b0ede5115297541d29eab4ce5f4fe905cb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 27 Apr 2018 11:00:41 +0800 Subject: [PATCH 392/593] [SPARK-23355][SQL] convertMetastore should not ignore table properties ## What changes were proposed in this pull request? Previously, SPARK-22158 fixed for `USING hive` syntax. This PR aims to fix for `STORED AS` syntax. Although the test case covers ORC part, the patch considers both `convertMetastoreOrc` and `convertMetastoreParquet`. ## How was this patch tested? Pass newly added test cases. Author: Dongjoon Hyun Closes #20522 from dongjoon-hyun/SPARK-22158-2. --- .../spark/sql/hive/HiveStrategies.scala | 17 +++- .../sql/hive/CompressionCodecSuite.scala | 7 +- .../sql/hive/execution/HiveDDLSuite.scala | 81 +++++++++++++++++++ 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8df05cbb20361..a0c197b06ddab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d10a6f25c64fc..4550d350f6db2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -268,12 +268,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - // For non-partitioned table and when convertMetastore is true, Expect session-level - // take effect, and in other cases expect table-level take effect - // TODO: It should always be table-level taking effect when the bug(SPARK-22926) - // is fixed - val expectCodec = - if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + val expectCodec = tableCodec.get assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c85db78c732de..daac6af9b557f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METAS import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -2144,6 +2145,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", "GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") From 109935fc5d8b3d381bb1b09a4a570040a0a1846f Mon Sep 17 00:00:00 2001 From: eric-maynard Date: Fri, 27 Apr 2018 15:25:07 +0800 Subject: [PATCH 393/593] [SPARK-23830][YARN] added check to ensure main method is found ## What changes were proposed in this pull request? When a user specifies the wrong class -- or, in fact, a class instead of an object -- Spark throws an NPE which is not useful for debugging. This was reported in [SPARK-23830](https://issues.apache.org/jira/browse/SPARK-23830). This PR adds a check to ensure the main method was found and logs a useful error in the event that it's null. ## How was this patch tested? * Unit tests + Manual testing * The scope of the changes is very limited Author: eric-maynard Author: Eric Maynard Closes #21168 from eric-maynard/feature/SPARK-23830. --- .../spark/deploy/yarn/ApplicationMaster.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 650840045361c..595077e7e809f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Modifier} import java.net.{Socket, URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -675,9 +675,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { From 2824f12b8bac5d86a82339d4dfb4d2625e978a15 Mon Sep 17 00:00:00 2001 From: Patrick McGloin Date: Fri, 27 Apr 2018 23:04:14 +0800 Subject: [PATCH 394/593] [SPARK-23565][SS] New error message for structured streaming sources assertion ## What changes were proposed in this pull request? A more informative message to tell you why a structured streaming query cannot continue if you have added more sources, than there are in the existing checkpoint offsets. ## How was this patch tested? I added a Unit Test. Author: Patrick McGloin Closes #20946 from patrickmcgloin/master. --- .../org/apache/spark/sql/execution/streaming/OffsetSeq.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73945b39b8967..787174481ff08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -39,7 +39,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * cannot be serialized). */ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } From 3fd297af6dc568357c97abf86760c570309d6597 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 27 Apr 2018 11:43:29 -0700 Subject: [PATCH 395/593] [SPARK-24085][SQL] Query returns UnsupportedOperationException when scalar subquery is present in partitioning expression ## What changes were proposed in this pull request? In this case, the partition pruning happens before the planning phase of scalar subquery expressions. For scalar subquery expressions, the planning occurs late in the cycle (after the physical planning) in "PlanSubqueries" just before execution. Currently we try to execute the scalar subquery expression as part of partition pruning and fail as it implements Unevaluable. The fix attempts to ignore the Subquery expressions from partition pruning computation. Another option can be to somehow plan the subqueries before the partition pruning. Since this may not be a commonly occuring expression, i am opting for a simpler fix. Repro ``` SQL CREATE TABLE test_prc_bug ( id_value string ) partitioned by (id_type string) location '/tmp/test_prc_bug' stored as parquet; insert into test_prc_bug values ('1','a'); insert into test_prc_bug values ('2','a'); insert into test_prc_bug values ('3','b'); insert into test_prc_bug values ('4','b'); select * from test_prc_bug where id_type = (select 'b'); ``` ## How was this patch tested? Added test in SubquerySuite and hive/SQLQuerySuite Author: Dilip Biswal Closes #21174 from dilipbiswal/spark-24085. --- .../datasources/FileSourceStrategy.scala | 5 ++- .../PruneFileSourcePartitions.scala | 4 ++- .../org/apache/spark/sql/SubquerySuite.scala | 15 +++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..0a568d6b8adce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -76,7 +76,10 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f0..16b2367bfdd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 31e8b0e8dede0..acef62d81ee12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -955,4 +955,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 73f83d593bbfb..704a410b6a37b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2156,4 +2156,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + } + } + } + } + } From 8614edd445264007144caa6743a8c2ca2b5082e0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 27 Apr 2018 14:14:28 -0700 Subject: [PATCH 396/593] [SPARK-24104] SQLAppStatusListener overwrites metrics onDriverAccumUpdates instead of updating them ## What changes were proposed in this pull request? Event `SparkListenerDriverAccumUpdates` may happen multiple times in a query - e.g. every `FileSourceScanExec` and `BroadcastExchangeExec` call `postDriverMetricUpdates`. In Spark 2.2 `SQLListener` updated the map with new values. `SQLAppStatusListener` overwrites it. Unless `update` preserved it in the KV store (dependant on `exec.lastWriteTime`), only the metrics from the last operator that does `postDriverMetricUpdates` are preserved. ## How was this patch tested? Unit test added. Author: Juliusz Sompolski Closes #21171 from juliuszsompolski/SPARK-24104. --- .../execution/ui/SQLAppStatusListener.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2b6bb48467eb3..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index f3f08839c1d3a..02df45d1b7989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -443,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -466,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -562,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. */ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple time in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } From 1fb46f30f83e4751169ff288ad406f26b7c11f7e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 28 Apr 2018 09:55:56 +0800 Subject: [PATCH 397/593] [SPARK-23688][SS] Refactor tests away from rate source ## What changes were proposed in this pull request? Replace rate source with memory source in continuous mode test suite. Keep using "rate" source if the tests intend to put data periodically in background, or need to put short source name to load, since "memory" doesn't have provider for source. ## How was this patch tested? Ran relevant test suite from IDE. Author: Jungtaek Lim Closes #21152 from HeartSaVioR/SPARK-23688. --- .../continuous/ContinuousSuite.scala | 163 +++++++----------- 1 file changed, 61 insertions(+), 102 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index c318b951ff992..5f222e7885994 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -75,73 +75,50 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -149,15 +126,11 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -165,58 +138,43 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("subquery alias") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .createOrReplaceTempView("rate") - val test = spark.sql("select value from rate where value > 5") + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") - testStream(test, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } test("task failure kills the query") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() // Get an arbitrary task from this query to kill. It doesn't matter which one. var taskId: Long = -1 @@ -227,9 +185,9 @@ class ContinuousSuite extends ContinuousSuiteBase { } spark.sparkContext.addSparkListener(listener) try { - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(100)), - Execute(waitForRateSourceTriggers(_, 2)), + AddData(input, 0, 1, 2, 3), Execute { _ => // Wait until a task is started, then kill its first attempt. eventually(timeout(streamingTimeout)) { @@ -252,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") From ad94e8592b2e8f4c1bdbd958e110797c6658af84 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 28 Apr 2018 10:47:43 +0800 Subject: [PATCH 398/593] [SPARK-23736][SQL][FOLLOWUP] Error message should contains SQL types ## What changes were proposed in this pull request? In the error messages we should return the SQL types (like `string` rather than the internal types like `StringType`). ## How was this patch tested? added UT Author: Marco Gaido Closes #21181 from mgaido91/SPARK-23736_followup. --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 90223b9126555..6d63a531e3b74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -863,8 +863,9 @@ case class Concat(children: Seq[Expression]) extends Expression { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + - s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c216d1322a06c..470a1c8e331ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -712,6 +712,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df.selectExpr("concat(i1, array(i1, i2))") } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) } test("flatten function") { From 4df51361a5ff1fba20524f1b580f4049b328ed32 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 28 Apr 2018 16:57:41 +0800 Subject: [PATCH 399/593] [SPARK-22732][SS][FOLLOW-UP] Fix MemorySinkV2 toString error ## What changes were proposed in this pull request? Fix `MemorySinkV2` toString() error ## How was this patch tested? N/A Author: Yuming Wang Closes #21170 from wangyum/SPARK-22732. --- .../spark/sql/execution/streaming/sources/memoryV2.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 5f58246083bb2..d871d37ad37c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -96,7 +96,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -107,7 +107,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.clear() } - override def toString(): String = "MemorySink" + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} @@ -175,7 +175,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { private val sizePerRow = output.map(_.dataType.defaultSize).sum From bd14da6fd5a77cc03efff193a84ffccbe892cc13 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 29 Apr 2018 11:25:31 +0800 Subject: [PATCH 400/593] [SPARK-23094][SPARK-23723][SPARK-23724][SQL] Support custom encoding for json files ## What changes were proposed in this pull request? I propose new option for JSON datasource which allows to specify encoding (charset) of input and output files. Here is an example of using of the option: ``` spark.read.schema(schema) .option("multiline", "true") .option("encoding", "UTF-16LE") .json(fileName) ``` If the option is not specified, charset auto-detection mechanism is used by default. The option can be used for saving datasets to jsons. Currently Spark is able to save datasets into json files in `UTF-8` charset only. The changes allow to save data in any supported charset. Here is the approximate list of supported charsets by Oracle Java SE: https://docs.oracle.com/javase/8/docs/technotes/guides/intl/encoding.doc.html . An user can specify the charset of output jsons via the charset option like `.option("charset", "UTF-16BE")`. By default the output charset is still `UTF-8` to keep backward compatibility. The solution has the following restrictions for per-line mode (`multiline = false`): - If charset is different from UTF-8, the lineSep option must be specified. The option required because Hadoop LineReader cannot detect the line separator correctly. Here is the ticket for solving the issue: https://issues.apache.org/jira/browse/SPARK-23725 - Encoding with [BOM](https://en.wikipedia.org/wiki/Byte_order_mark) are not supported. For example, the `UTF-16` and `UTF-32` encodings are blacklisted. The problem can be solved by https://github.com/MaxGekk/spark-1/pull/2 ## How was this patch tested? I added the following tests: - reads an json file in `UTF-16LE` encoding with BOM in `multiline` mode - read json file by using charset auto detection (`UTF-32BE` with BOM) - read json file using of user's charset (`UTF-16LE`) - saving in `UTF-32BE` and read the result by standard library (not by Spark) - checking that default charset is `UTF-8` - handling wrong (unsupported) charset Author: Maxim Gekk Author: Maxim Gekk Closes #20937 from MaxGekk/json-encoding-line-sep. --- python/pyspark/sql/readwriter.py | 15 +- python/pyspark/sql/tests.py | 7 + .../sql/people_array_utf16le.json | Bin 0 -> 182 bytes .../catalyst/json/CreateJacksonParser.scala | 49 +++- .../spark/sql/catalyst/json/JSONOptions.scala | 39 ++- .../sql/catalyst/json/JacksonParser.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 3 + .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../datasources/json/JsonDataSource.scala | 60 +++-- .../datasources/json/JsonFileFormat.scala | 10 +- .../datasources/text/TextOptions.scala | 18 +- .../src/test/resources/test-data/utf16LE.json | Bin 0 -> 98 bytes .../resources/test-data/utf16WithBOM.json | Bin 0 -> 200 bytes .../resources/test-data/utf32BEWithBOM.json | Bin 0 -> 172 bytes .../datasources/json/JsonBenchmarks.scala | 179 +++++++++++++ .../datasources/json/JsonSuite.scala | 245 +++++++++++++++++- 16 files changed, 599 insertions(+), 44 deletions(-) create mode 100644 python/test_support/sql/people_array_utf16le.json create mode 100644 sql/core/src/test/resources/test-data/utf16LE.json create mode 100644 sql/core/src/test/resources/test-data/utf16WithBOM.json create mode 100644 sql/core/src/test/resources/test-data/utf32BEWithBOM.json create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index df176c579fc8b..6811fa6b3b156 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +238,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. @@ -259,7 +264,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio) + samplingRatio=samplingRatio, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -752,7 +757,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None): + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -776,6 +781,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. @@ -784,7 +791,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep) + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6b28c557a803e..e0cd2aa41a2d0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -685,6 +685,13 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_linesep_json(self): df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") expected = [Row(_corrupt_record=None, name=u'Michael'), diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000000000000000000000000000000..9c657fa30ac9c651076ff8aa3676baa400b121fb GIT binary patch literal 182 zcma!M;9^h!!fGfDVk require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } - // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) - // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None), it will be detected automatically + * when the multiLine option is set to `true`. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map { enc => + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) + val isBlacklisted = blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + | ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = !(multiLine == false && + Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") /** Sets config options on a Jackson [[JsonFactory]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..a5a4a13eb608b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -361,6 +361,14 @@ class JacksonParser( // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b44552f0eb17b..6b2ea6c06d3ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -372,6 +372,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `encoding` (by default it is not set): allows to forcibly set one of standard basic + * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
  • *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bbc063148a72c..e183fa6f9542b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,8 +518,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `encoding` (by default it is not set): specifies encoding (charset) of saved json + * files. If it is not set, the UTF-8 charset will be used.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.4.0 @@ -589,8 +590,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5769c09c9a1d9..983a5f0dcade2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,26 +92,30 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset( - sparkSession, inputPaths, parsedOptions.lineSeparator) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - lineSeparator: Option[String]): Dataset[String] = { - val textOptions = lineSeparator.map { lineSep => - Map(TextOptions.LINE_SEPARATOR -> lineSep) - }.getOrElse(Map.empty[String, String]) - + parsedOptions: JSONOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, @@ -129,8 +133,12 @@ object TextInputJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -153,7 +161,11 @@ object MultiLineJsonDataSource extends JsonDataSource { parsedOptions: JSONOptions): StructType = { val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( @@ -175,11 +187,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - val path = new Path(record.getPath()) - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -194,9 +213,12 @@ object MultiLineJsonDataSource extends JsonDataSource { UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..3b04510d29695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -151,7 +153,13 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 5c1a35434f7b5..e4e201995faa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.text -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} @@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean - private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => - require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") - sep + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep } + // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } val lineSeparatorInWrite: Array[Byte] = lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } @@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 0000000000000000000000000000000000000000..ce4117fd299dfcbc7089e7c0530098bfcaf5a27e GIT binary patch literal 98 zcmbi20w;GhFpeJpqLd9J2PYe#WR62N(?$cl`!==KvkHk Roq(bsb5ek+xfp7J7y!-s4k`cu literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 0000000000000000000000000000000000000000..cf4d29328b860ffe8288edea437222c6d432a100 GIT binary patch literal 200 zcmezWFPedufr~)_3ac5E7}6Lr8HyN+8A=%Z7!nzB8B&2_RzU2`kO36W1j;Be=m6C# tG2{T{G1WN%ML{N{09DiiRT68y3qw9bDMLB|(}RGj@}XvfOpXPc4*2#AY;xCDs(fH)C|bAdP&h(YSCptLiP&H!SNdXPSl f9+12a5Gz30IY1hupBVF;plV@mNP(JB3#7RK --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a58dff827b92d..0db688fec9a67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets +import java.io.{File, FileOutputStream, StringWriter} +import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -48,6 +48,10 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2167,4 +2171,241 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val sampled = spark.read.option("samplingRatio", 1.0).json(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + } + } + + assert(exception.getMessage == encoding) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read + } else { + spark.read.schema(schema) + } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } + + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } } From 56f501e1c0cec3be7d13008bd2c0182ec83ed2a2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Apr 2018 09:40:46 +0800 Subject: [PATCH 401/593] [MINOR][DOCS] Fix a broken link for Arrow's supported types in the programming guide ## What changes were proposed in this pull request? This PR fixes a broken link for Arrow's supported types in the programming guide. ## How was this patch tested? Manually tested via `SKIP_API=1 jekyll watch`. "Supported SQL Types" here in https://spark.apache.org/docs/latest/sql-programming-guide.html#enabling-for-conversion-tofrom-pandas is broken. It should be https://spark.apache.org/docs/latest/sql-programming-guide.html#supported-sql-types Author: hyukjinkwon Closes #21191 from HyukjinKwon/minor-arrow-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8ff1470970f7..836ce990205a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1703,7 +1703,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) From 3121b411f748859ed3ed1c97cbc21e6ae980a35c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 30 Apr 2018 09:45:22 +0800 Subject: [PATCH 402/593] [SPARK-23846][SQL] The samplingRatio option for CSV datasource ## What changes were proposed in this pull request? I propose to support the `samplingRatio` option for schema inferring of CSV datasource similar to the same option of JSON datasource: https://github.com/apache/spark/blob/b14993e1fcb68e1c946a671c6048605ab4afdf58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala#L49-L50 ## How was this patch tested? Added 2 tests for json and 2 tests for csv datasources. The tests checks that only subset of input dataset is used for schema inferring. Author: Maxim Gekk Author: Maxim Gekk Closes #20959 from MaxGekk/csv-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 7 +++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/csv/CSVOptions.scala | 3 ++ .../execution/datasources/csv/CSVUtils.scala | 28 +++++++++++ .../execution/datasources/csv/CSVSuite.scala | 47 ++++++++++++++++++- .../datasources/csv/TestCsvData.scala | 36 ++++++++++++++ 8 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6811fa6b3b156..9899eb5058b82 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -345,7 +345,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -428,6 +429,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -446,7 +449,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e0cd2aa41a2d0..bc3eaf16b4de7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3033,6 +3033,13 @@ def test_json_sampling_ratio(self): .json(rdd).schema self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6b2ea6c06d3ae..53f44888ebaff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -539,6 +539,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `header` (default `false`): uses the first line as names of columns.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • *
  • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
  • *
  • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4870d75fc5f08..bc1f4ab3bb053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -161,7 +161,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -235,7 +236,8 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce17..2ec0fc605a84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -150,6 +150,9 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092ca..31464f1bcc68e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -131,4 +134,29 @@ object CSVUtils { schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV dataset as configured by `samplingRatio`. + */ + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample CSV RDD as configured by `samplingRatio`. + */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d9217..461abdd96d3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -1279,4 +1278,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 0000000000000..3e20cc47dca2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} From b42ad165bb93c96cc5be9ed05b5026f9baafdfa2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Apr 2018 09:13:32 -0700 Subject: [PATCH 403/593] [SPARK-24072][SQL] clearly define pushed filters ## What changes were proposed in this pull request? filters like parquet row group filter, which is actually pushed to the data source but still to be evaluated by Spark, should also count as `pushedFilters`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21143 from cloud-fan/step1. --- .../SupportsPushDownCatalystFilters.java | 11 +++- .../v2/reader/SupportsPushDownFilters.java | 10 ++- .../datasources/v2/DataSourceV2Relation.scala | 63 +++++++++++-------- .../datasources/v2/DataSourceV2ScanExec.scala | 1 + .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/DataSourceV2StringFormat.scala | 19 ++---- .../v2/PushDownOperatorsToDataSource.scala | 6 +- .../continuous/ContinuousSuite.scala | 2 +- 8 files changed, 68 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 290d614805ac7..4543c143a9aca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -39,7 +39,16 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { Expression[] pushCatalystFilters(Expression[] filters); /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * Returns the catalyst filters that are pushed to the data source via + * {@link #pushCatalystFilters(Expression[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for * this case. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 1cff024232a44..b6a90a3d0b681 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -37,7 +37,15 @@ public interface SupportsPushDownFilters extends DataSourceReader { Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2b282ffae2390..90fb5a14c9fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -54,9 +54,12 @@ case class DataSourceV2Relation( private lazy val v2Options: DataSourceOptions = makeV2Options(options) + // postScanFilters: filters that need to be evaluated after the scan. + // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. + // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. lazy val ( reader: DataSourceReader, - unsupportedFilters: Seq[Expression], + postScanFilters: Seq[Expression], pushedFilters: Seq[Expression]) = { val newReader = userSpecifiedSchema match { case Some(s) => @@ -67,14 +70,16 @@ case class DataSourceV2Relation( DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - val (remainingFilters, pushedFilters) = filters match { + val (postScanFilters, pushedFilters) = filters match { case Some(filterSeq) => DataSourceV2Relation.pushFilters(newReader, filterSeq) case _ => (Nil, Nil) } + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - (newReader, remainingFilters, pushedFilters) + (newReader, postScanFilters, pushedFilters) } override def doCanonicalize(): LogicalPlan = { @@ -121,6 +126,8 @@ case class StreamingDataSourceV2Relation( override def simpleString: String = "Streaming RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Nil + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. @@ -217,31 +224,35 @@ object DataSourceV2Relation { reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { - case catalystFilterSupport: SupportsPushDownCatalystFilters => - ( - catalystFilterSupport.pushCatalystFilters(filters.toArray), - catalystFilterSupport.pushedCatalystFilters() - ) - - case filterSupport: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet - val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => - unhandledFilters.contains(f) + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (postScanFilters, pushedFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } } - (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + + (untranslatableExprs ++ postScanFilters, pushedFilters) case _ => (filters, Nil) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e142..41bdda47c8c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -41,6 +41,7 @@ case class DataSourceV2ScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c2a31442d2be5..1b7c639f10f98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDat object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index aed55a429bfd7..693e67dcd108e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.util.Utils /** @@ -49,16 +47,9 @@ trait DataSourceV2StringFormat { def options: Map[String, String] /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. + * The filters which have been pushed to the data source. */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } + def pushedFilters: Seq[Expression] private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() @@ -68,8 +59,8 @@ trait DataSourceV2StringFormat { def metadataString: String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) { - entries += "Filters" -> filters.mkString("[", ", ", "]") + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") } // TODO: we should only display some standard options like path, table, etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index f23d228567241..9293d4f831bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -57,9 +57,9 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { projection = projection.asInstanceOf[Seq[AttributeReference]], filters = Some(filters)) - // Add a Filter for any filters that could not be pushed - val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) - val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) + // Add a Filter for any filters that need to be evaluated after scan. + val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) + val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) // Add a Project to ensure the output matches the required projection if (newRelation.output != projectAttrs) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 5f222e7885994..cd1704ac2fdad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 007ae6878f4b4defe1f08114212fa7289fc9ee4a Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Mon, 30 Apr 2018 13:40:03 -0700 Subject: [PATCH 404/593] [SPARK-24003][CORE] Add support to provide spark.executor.extraJavaOptions in terms of App Id and/or Executor Id's ## What changes were proposed in this pull request? Added support to specify the 'spark.executor.extraJavaOptions' value in terms of the `{{APP_ID}}` and/or `{{EXECUTOR_ID}}`, `{{APP_ID}}` will be replaced by Application Id and `{{EXECUTOR_ID}}` will be replaced by Executor Id while starting the executor. ## How was this patch tested? I have verified this by checking the executor process command and gc logs. I verified the same in different deployment modes(Standalone, YARN, Mesos) client and cluster modes. Author: Devaraj K Closes #21088 from devaraj-kavali/SPARK-24003. --- .../spark/deploy/worker/ExecutorRunner.scala | 8 ++++++-- .../main/scala/org/apache/spark/util/Utils.scala | 15 +++++++++++++++ docs/configuration.md | 5 +++++ .../k8s/features/BasicExecutorFeatureStep.scala | 4 +++- .../MesosCoarseGrainedSchedulerBackend.scala | 4 +++- .../mesos/MesosFineGrainedSchedulerBackend.scala | 4 +++- .../org/apache/spark/deploy/yarn/Client.scala | 8 ++++++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 ++- 8 files changed, 43 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc8204..dc6a3076a5113 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d2be93226e2a2..dcad1b914038f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2689,6 +2689,21 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + } + + /** + * Replaces all the {{APP_ID}} occurrences with the App Id. + */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } } private[util] object CallerContext extends Logging { diff --git a/docs/configuration.md b/docs/configuration.md index fb02d7ea1d4ea..8a1aacef85760 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -328,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc
  • diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d22097587aafe..529069d3b8a0c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -89,7 +89,9 @@ private[spark] class BasicExecutorFeatureStep( val executorExtraJavaOptionsEnv = kubernetesConf .get(EXECUTOR_JAVA_OPTIONS) .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca486..9b75e4c98344a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d246109..71a70ff048ccc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => Utils.libraryPathEnvPrefix(Seq(p)) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5763c3dbc5a8a..bafb129032b49 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -892,7 +892,9 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten @@ -914,7 +916,9 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab08698035c98..a2a18cdff65af 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -141,7 +141,8 @@ private[yarn] class ExecutorRunnable( // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) From b857fb549f3bf4e6f289ba11f3903db0a3696dec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 May 2018 09:06:23 +0800 Subject: [PATCH 405/593] [SPARK-23853][PYSPARK][TEST] Run Hive-related PySpark tests only for `-Phive` ## What changes were proposed in this pull request? When `PyArrow` or `Pandas` are not available, the corresponding PySpark tests are skipped automatically. Currently, PySpark tests fail when we are not using `-Phive`. This PR aims to skip Hive related PySpark tests when `-Phive` is not given. **BEFORE** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql File "/Users/dongjoon/spark/python/pyspark/sql/readwriter.py", line 295, in pyspark.sql.readwriter.DataFrameReader.table ... IllegalArgumentException: u"Error while instantiating 'org.apache.spark.sql.hive.HiveExternalCatalog':" ********************************************************************** 1 of 3 in pyspark.sql.readwriter.DataFrameReader.table ***Test Failed*** 1 failures. ``` **AFTER** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql ... Tests passed in 138 seconds Skipped tests in pyspark.sql.tests with python2.7: ... test_hivecontext (pyspark.sql.tests.HiveSparkSubmitTests) ... skipped 'Hive is not available.' ``` ## How was this patch tested? This is a test-only change. First, this should pass the Jenkins. Then, manually do the following. ```bash build/mvn -DskipTests clean package python/run-tests.py --python-executables python2.7 --modules pyspark-sql ``` Author: Dongjoon Hyun Closes #21141 from dongjoon-hyun/SPARK-23853. --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/tests.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9899eb5058b82..448a4732001b5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -979,7 +979,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bc3eaf16b4de7..cc6acfdb07d99 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3043,6 +3043,26 @@ def test_csv_sampling_ratio(self): class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by From 7bbec0dced35aeed79c1a24b6f7a1e0a3508b0fb Mon Sep 17 00:00:00 2001 From: wangyanlin01 Date: Tue, 1 May 2018 16:22:52 +0800 Subject: [PATCH 406/593] [SPARK-24061][SS] Add TypedFilter support for continuous processing ## What changes were proposed in this pull request? Add TypedFilter support for continuous processing application. ## How was this patch tested? unit tests Author: wangyanlin01 Closes #21136 from yanlin-Lynn/SPARK-24061. --- .../UnsupportedOperationChecker.scala | 3 ++- .../analysis/UnsupportedOperationsSuite.scala | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dded..d3d6c636c4ba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,8 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda264..cb487c8893541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous procsssing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan. * @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } From 6782359a04356e4cde32940861bf2410ef37f445 Mon Sep 17 00:00:00 2001 From: Bounkong Khamphousone Date: Tue, 1 May 2018 08:28:21 -0700 Subject: [PATCH 407/593] [SPARK-23941][MESOS] Mesos task failed on specific spark app name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Shell escaped the name passed to spark-submit and change how conf attributes are shell escaped. ## How was this patch tested? This test has been tested manually with Hive-on-spark with mesos or with the use case described in the issue with the sparkPi application with a custom name which contains illegal shell characters. With this PR, hive-on-spark on mesos works like a charm with hive 3.0.0-SNAPSHOT. I state that this contribution is my original work and that I license the work to the project under the project’s open source license Author: Bounkong Khamphousone Closes #21014 from tiboun/fix/SPARK-23941. --- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820a..b36f46456f9a5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -530,9 +530,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** From e15850be6e0210614a734a307f5b83bdf44e2456 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 May 2018 10:55:01 +0800 Subject: [PATCH 408/593] [SPARK-24131][PYSPARK] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? We need to determine Spark major and minor versions in PySpark. We can add a `majorMinorVersion` API to PySpark which is similar to the Scala API in `VersionUtils.majorMinorVersion`. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21203 from viirya/SPARK-24131. --- python/pyspark/util.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 49afc13640332..04df835bf6717 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import re import sys import inspect from py4j.protocol import Py4JJavaError @@ -61,6 +62,26 @@ def _get_argspec(f): return argspec +def majorMinorVersion(version): + """ + Get major and minor version numbers for given Spark version string. + + >>> version = "2.4.0" + >>> majorMinorVersion(version) + (2, 4) + + >>> version = "abc" + >>> majorMinorVersion(version) is None + True + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', version) + if m is None: + return None + else: + return (int(m.group(1)), int(m.group(2))) + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9215ee7a16b57c56ae927d65e024cf7afe542cbb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 10:41:34 +0200 Subject: [PATCH 409/593] [SPARK-23976][CORE] Detect length overflow in UTF8String.concat()/ByteArray.concat() ## What changes were proposed in this pull request? This PR detects length overflow if total elements in inputs are not acceptable. For example, when the three inputs has `0x7FFF_FF00`, `0x7FFF_FF00`, and `0xE00`, we should detect length overflow since we cannot allocate such a large structure on `byte[]`. On the other hand, the current algorithm can allocate the result structure with `0x1000`-byte length due to integer sum overflow. ## How was this patch tested? Existing UTs. If we would create UTs, we need large heap (6-8GB). It may make test environment unstable. If it is necessary to create UTs, I will create them. Author: Kazuaki Ishizaki Closes #21064 from kiszk/SPARK-23976. --- .../org/apache/spark/unsafe/types/ByteArray.java | 12 +++++++----- .../org/apache/spark/unsafe/types/UTF8String.java | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f61..ecd7c19f2c634 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e9b3d9b045af5..e91fc4391425c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,8 +29,8 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; - import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; From 152eaf6ae698cd0df7f5a5be3f17ee46e0be929d Mon Sep 17 00:00:00 2001 From: WangJinhai02 Date: Wed, 2 May 2018 22:40:14 +0800 Subject: [PATCH 410/593] [SPARK-24107][CORE] ChunkedByteBuffer.writeFully method has not reset the limit value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue: https://issues.apache.org/jira/browse/SPARK-24107?jql=text%20~%20%22ChunkedByteBuffer%22 ChunkedByteBuffer.writeFully method has not reset the limit value. When chunks larger than bufferWriteChunkSize, such as 80 * 1024 * 1024 larger than config.BUFFER_WRITE_CHUNK_SIZE(64 * 1024 * 1024),only while once, will lost 16 * 1024 * 1024 byte Author: WangJinhai02 Closes #21175 from manbuyun/bugfix-ChunkedByteBuffer. --- .../spark/util/io/ChunkedByteBuffer.scala | 13 +++++++++---- .../spark/io/ChunkedByteBufferSuite.scala | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd8..3ae8dfcc1cb66 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) + val curChunkLimit = bytes.limit() + while (bytes.hasRemaining) { + try { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + } finally { + bytes.limit(curChunkLimit) + } } } } diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b0499..2107559572d78 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) From 8dbf56c055218ff0f3fabae84b63c022f43afbfd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 11:58:55 -0700 Subject: [PATCH 411/593] [SPARK-24013][SQL] Remove unneeded compress in ApproximatePercentile ## What changes were proposed in this pull request? `ApproximatePercentile` contains a workaround logic to compress the samples since at the beginning `QuantileSummaries` was ignoring the compression threshold. This problem was fixed in SPARK-17439, but the workaround logic was not removed. So we are compressing the samples many more times than needed: this could lead to critical performance degradation. This can create serious performance issues in queries like: ``` select approx_percentile(id, array(0.1)) from range(10000000) ``` ## How was this patch tested? added UT Author: Marco Gaido Closes #21133 from mgaido91/SPARK-24013. --- .../aggregate/ApproximatePercentile.scala | 33 ++++--------------- .../sql/catalyst/util/QuantileSummaries.scala | 11 ++++--- .../sql/ApproximatePercentileQuerySuite.scala | 13 ++++++++ 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index a45854a3b5146..f1bbbdabb41f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. - val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. */ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -335,8 +314,8 @@ object ApproximatePercentile { sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index b013add9c9778..3190e511e2cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. * @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb9..d635912cf7205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } From 8bd27025b7cf0b44726b6f4020d294ef14dbbb7e Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Wed, 2 May 2018 12:43:19 -0700 Subject: [PATCH 412/593] [SPARK-24133][SQL] Check for integer overflows when resizing WritableColumnVectors ## What changes were proposed in this pull request? `ColumnVector`s store string data in one big byte array. Since the array size is capped at just under Integer.MAX_VALUE, a single `ColumnVector` cannot store more than 2GB of string data. But since the Parquet files commonly contain large blobs stored as strings, and `ColumnVector`s by default carry 4096 values, it's entirely possible to go past that limit. In such cases a negative capacity is requested from `WritableColumnVector.reserve()`. The call succeeds (requested capacity is smaller than already allocated capacity), and consequently `java.lang.ArrayIndexOutOfBoundsException` is thrown when the reader actually attempts to put the data into the array. This change introduces a simple check for integer overflow to `WritableColumnVector.reserve()` which should help catch the error earlier and provide more informative exception. Additionally, the error message in `WritableColumnVector.throwUnsupportedException()` was corrected, as it previously encouraged users to increase rather than reduce the batch size. ## How was this patch tested? New units tests were added. Author: Ala Luszczak Closes #21206 from ala/overflow-reserve. --- .../vectorized/WritableColumnVector.java | 21 ++++++++++++------- .../vectorized/ColumnarBatchSuite.scala | 7 +++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5275e4a91eac0..b0e119d658cb4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -81,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader, or increase the vectorized reader batch size. For parquet file " + - "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + - SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 772f687526008..f57f07b498261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1333,4 +1333,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.close() } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } From 504c9cfd21ef45a13d9428fef3b197dcbf6786cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 13:49:15 -0700 Subject: [PATCH 413/593] [SPARK-24123][SQL] Fix precision issues in monthsBetween with more than 8 digits ## What changes were proposed in this pull request? SPARK-23902 introduced the ability to retrieve more than 8 digits in `monthsBetween`. Unfortunately, current implementation can cause precision loss in such a case. This was causing also a flaky UT. This PR mirrors Hive's implementation in order to avoid precision loss also when more than 8 digits are returned. ## How was this patch tested? running 10000000 times the flaky UT Author: Marco Gaido Closes #21196 from mgaido91/SPARK-24123. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4b00a61c6cf91..d2fe15c48c6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -888,14 +888,19 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D if (roundOff) { // rounding to 8 digits math.round(diff * 1e8) / 1e8 From 5be8aab14468e55b1049a0c83f02dcec0651162f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 13:53:10 -0700 Subject: [PATCH 414/593] [SPARK-23923][SQL] Add cardinality function ## What changes were proposed in this pull request? The PR adds the SQL function `cardinality`. The behavior of the function is based on Presto's one. The function returns the length of the array or map stored in the column as `int` while the Presto version returns the value as `BigInt` (`long` in Spark). The discussions regarding the difference of return type are [here](https://github.com/apache/spark/pull/21031#issuecomment-381284638) and [there](https://github.com/apache/spark/pull/21031#discussion_r181622107). ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21031 from kiszk/SPARK-23923. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6bc7b4e4f7cb3..3ffbc9c8069fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 470a1c8e331ba..a5163accb1bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -341,6 +341,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(-1)) ) + + checkAnswer( + df.selectExpr("cardinality(a)"), + Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) + ) } test("map size function") { From e4c91c089a701117af82f585d14d8afc5245fc64 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 2 May 2018 16:12:21 -0700 Subject: [PATCH 415/593] [SPARK-24111][SQL] Add the TPCDS v2.7 (latest) queries in TPCDSQueryBenchmark ## What changes were proposed in this pull request? This pr added the TPCDS v2.7 (latest) queries in `TPCDSQueryBenchmark`. These query files have been added in `SPARK-23167`. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #21177 from maropu/AddTpcdsV2_7InBenchmark. --- .../benchmark/TPCDSQueryBenchmark.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9aa..abe61a2c2b9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } From bf4352ca6c96dfab16b286c54720685e32b216f1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 May 2018 09:28:14 +0800 Subject: [PATCH 416/593] [SPARK-24110][THRIFT-SERVER] Avoid UGI.loginUserFromKeytab in STS ## What changes were proposed in this pull request? Spark ThriftServer will call UGI.loginUserFromKeytab twice in initialization. This is unnecessary and will cause various potential problems, like Hadoop IPC failure after 7 days, or RM failover issue and so on. So here we need to remove all the unnecessary login logics and make sure UGI in the context never be created again. Note this is actually a HS2 issue, If later on we upgrade supported Hive version, the issue may already be fixed in Hive side. ## How was this patch tested? Local verification in secure cluster. Author: jerryshao Closes #21178 from jerryshao/SPARK-24110. --- .../hive/service/auth/HiveAuthFactory.java | 62 +++++++++++++++++-- .../thriftserver/SparkSQLCLIService.scala | 20 +++++- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade65283045..10000f12ab329 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. + saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a7..1335e16e35882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => From c9bfd1c6f8d16890ea1e5bc2bcb654a3afb32591 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 May 2018 15:15:05 +0800 Subject: [PATCH 417/593] [SPARK-23489][SQL][TEST] HiveExternalCatalogVersionsSuite should verify the downloaded file ## What changes were proposed in this pull request? Although [SPARK-22654](https://issues.apache.org/jira/browse/SPARK-22654) made `HiveExternalCatalogVersionsSuite` download from Apache mirrors three times, it has been flaky because it didn't verify the downloaded file. Some Apache mirrors terminate the downloading abnormally, the *corrupted* file shows the following errors. ``` gzip: stdin: not in gzip format tar: Child returned status 1 tar: Error is not recoverable: exiting now 22:46:32.700 WARN org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.hive.HiveExternalCatalogVersionsSuite, thread names: Keep-Alive-Timer ===== *** RUN ABORTED *** java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "/tmp/test-spark/spark-2.2.0"): error=2, No such file or directory ``` This has been reported weirdly in two ways. For example, the above case is reported as Case 2 `no failures`. - Case 1. [Test Result (1 failure / +1)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4389/) - Case 2. [Test Result (no failures)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4811/) This PR aims to make `HiveExternalCatalogVersionsSuite` more robust by verifying the downloaded `tgz` file by extracting and checking the existence of `bin/spark-submit`. If it turns out that the file is empty or corrupted, `HiveExternalCatalogVersionsSuite` will do retry logic like the download failure. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21210 from dongjoon-hyun/SPARK-23489. --- .../HiveExternalCatalogVersionsSuite.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6ca58e68d31eb..ea86ab9772bc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -67,7 +67,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } catch { case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } @@ -75,20 +89,6 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -161,7 +161,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( From 417ad92502e714da71552f64d0e1257d2fd5d3d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:27:01 +0800 Subject: [PATCH 418/593] [SPARK-23715][SQL] the input of to/from_utc_timestamp can not have timezone ## What changes were proposed in this pull request? `from_utc_timestamp` assumes its input is in UTC timezone and shifts it to the specified timezone. When the timestamp contains timezone(e.g. `2018-03-13T06:18:23+00:00`), Spark breaks the semantic and respect the timezone in the string. This is not what user expects and the result is different from Hive/Impala. `to_utc_timestamp` has the same problem. More details please refer to the JIRA ticket. This PR fixes this by returning null if the input timestamp contains timezone. ## How was this patch tested? new tests Author: Wenchen Fan Closes #21169 from cloud-fan/from_utc_timezone. --- docs/sql-programming-guide.md | 13 +- .../sql/catalyst/analysis/TypeCoercion.scala | 30 +++- .../expressions/datetimeExpressions.scala | 42 ++++++ .../sql/catalyst/util/DateTimeUtils.scala | 22 ++- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../catalyst/analysis/TypeCoercionSuite.scala | 12 +- .../resources/sql-tests/inputs/datetime.sql | 33 +++++ .../sql-tests/results/datetime.sql.out | 135 +++++++++++++++++- .../apache/spark/sql/DateFunctionsSuite.scala | 8 ++ 9 files changed, 283 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 836ce990205a9..075b953a0898e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,12 +1805,13 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 25bad28a2a209..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -776,12 +776,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. + case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -798,7 +819,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -814,6 +835,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d882d06cfd625..76aa61415a11f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1016,6 +1016,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. + */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = s""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d2fe15c48c6dd..e646da0659e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -296,10 +296,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +435,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293eca..3942240c442b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1208,6 +1208,13 @@ object SQLConf { .stringConf .createWithDefault("") + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1cc431aaf0a60..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,11 +536,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -823,7 +823,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..4950a4b7a4e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,36 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..9eede305dbdcc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 26 -- !query 0 @@ -82,9 +82,138 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f712baa7a9134..237412aa692e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -696,4 +697,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } From 991b526992bcf1dc1268578b650916569b12f583 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:56:30 +0800 Subject: [PATCH 419/593] [SPARK-24166][SQL] InMemoryTableScanExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from https://github.com/apache/spark/pull/21190 , to make it easier to backport. `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21223 from cloud-fan/minor1. --- .../execution/columnar/InMemoryTableScanExec.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index ea315fb71617c..0b4dd76c7d860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,10 +78,12 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) @@ -101,10 +103,13 @@ case class InMemoryTableScanExec( private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled if (supportsBatch) { // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. - buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] } else { val numOutputRows = longMetric("numOutputRows") From 96a50016bb0fb1cc57823a6706bff2467d671efd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 23:36:09 +0800 Subject: [PATCH 420/593] [SPARK-24169][SQL] JsonToStructs should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21226 from cloud-fan/minor4. --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/jsonExpressions.scala | 16 ++-- .../expressions/JsonExpressionsSuite.scala | 76 +++++++++---------- .../org/apache/spark/sql/functions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 +- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3ffbc9c8069fd..51bb6b0abe408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -534,7 +534,9 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted val expectedNumberOfParameters = if (validParametersCount.length == 1) { validParametersCount.head.toString } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index fdd672c416a03..34161f0f03f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -514,11 +514,10 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String], + forceNullableSchema: Boolean) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) - // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -532,14 +531,21 @@ case class JsonToStructs( schema = JsonExprUtils.validateSchemaLiteral(schema), options = Map.empty[String, String], child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.validateSchemaLiteral(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + + // Used in `org.apache.spark.sql.functions` + def this(schema: DataType, options: Map[String, String], child: Expression) = + this(schema, options, child, timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7812319756eae..00e97637eee7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,7 +512,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID)), + Option(tz.getID), + true), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -521,7 +522,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId), + gmtId, + true), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -530,7 +532,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), null ) } @@ -685,27 +687,23 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { - val input = - """{ - | "a": 1, - | "c": "foo" - |} - |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - checkEvaluation( - JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), - output - ) - val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) - .dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) - } + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs( + jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 25afaacc38d6f..d2e22fa355514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3179,7 +3179,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStructs(schema, options, e.expr) + new JsonToStructs(schema, options, e.expr) } /** diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 581dddc89d0bb..14a69128ffb41 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 From 94641fe6cc68e5977dd8663b8f232a287a783acb Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 3 May 2018 10:59:18 -0500 Subject: [PATCH 421/593] [SPARK-23433][CORE] Late zombie task completions update all tasksets Fetch failure lead to multiple tasksets which are active for a given stage. While there is only one "active" version of the taskset, the earlier attempts can still have running tasks, which can complete successfully. So a task completion needs to update every taskset so that it knows the partition is completed. That way the final active taskset does not try to submit another task for the same partition, and so that it knows when it is completed and when it should be marked as a "zombie". Added a regression test. Author: Imran Rashid Closes #21131 from squito/SPARK-23433. --- .../spark/scheduler/TaskSchedulerImpl.scala | 14 +++ .../spark/scheduler/TaskSetManager.scala | 20 +++- .../scheduler/TaskSchedulerImplSuite.scala | 104 ++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981b..8e97b3da33820 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8a96a7692f614..195fc8025e4b5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -73,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -153,7 +155,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -754,6 +756,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -764,6 +769,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7bef..33f2ea1c94e75 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } From e3201e165e41f076ec72175af246d12c0da529cf Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 3 May 2018 17:05:02 -0700 Subject: [PATCH 422/593] [SPARK-24035][SQL] SQL syntax for Pivot ## What changes were proposed in this pull request? Add SQL support for Pivot according to Pivot grammar defined by Oracle (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_clause.htm) with some simplifications, based on our existing functionality and limitations for Pivot at the backend: 1. For pivot_for_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_for_clause.htm), the column list form is not supported, which means the pivot column can only be one single column. 2. For pivot_in_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_in_clause.htm), the sub-query form and "ANY" is not supported (this is only supported by Oracle for XML anyway). 3. For pivot_in_clause, aliases for the constant values are not supported. The code changes are: 1. Add parser support for Pivot. Note that according to https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#i2076542, Pivot cannot be used together with lateral views in the from clause. This restriction has been implemented in the Parser rule. 2. Infer group-by expressions: group-by expressions are not explicitly specified in SQL Pivot clause and need to be deduced based on this rule: https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#CHDFAFIE, so we have to post-fix it at query analysis stage. 3. Override Pivot.resolved as "false": for the reason mentioned in [2] and the fact that output attributes change after Pivot being replaced by Project or Aggregate, we avoid resolving parent references until after Pivot has been resolved and replaced. 4. Verify aggregate expressions: only aggregate expressions with or without aliases can appear in the first part of the Pivot clause, and this check is performed as analysis stage. ## How was this patch tested? A new test suite PivotSuite is added. Author: maryannxue Closes #21187 from maryannxue/spark-24035. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 35 +++- .../sql/catalyst/parser/AstBuilder.scala | 20 +- .../plans/logical/basicLogicalOperators.scala | 27 ++- .../parser/TableIdentifierParserSuite.scala | 6 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 113 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 194 ++++++++++++++++++ 8 files changed, 386 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/pivot.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/pivot.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5fa75fe348e68..f7f921ec22c35 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* lateralView* + : FROM relation (',' relation)* (pivotClause | lateralView*)? ; aggregation @@ -413,6 +413,10 @@ groupingSet | expression ; +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + ; + lateralView : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? ; @@ -725,7 +729,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -745,7 +749,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -760,6 +764,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -805,6 +810,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7c..dfdcdbc1eb2c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -275,9 +275,9 @@ class Analyzer( case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -504,9 +504,20 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + // Check all aggregate expressions. + aggregates.foreach { e => + if (!isAggregateExpression(e)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$e'") + } + } + // Group-by expressions coming from SQL are implicit and need to be deduced. + val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) @@ -568,16 +579,20 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + private def isAggregateExpression(expr: Expression): Boolean = { + expr match { + case Alias(e, _) => isAggregateExpression(e) + case AggregateExpression(_, _, _, _) => true + case _ => false + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a878..64eed23884584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** @@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) + val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df504795430..3bf32ef7884e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -686,17 +686,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. + * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998d..89903c2825125 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7147798d99533..6c2be3610ae30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -73,7 +73,7 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 0000000000000..01dea6c81c11b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,113 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 0000000000000..85e3488990e20 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 13 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 2 schema +struct +-- !query 2 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 3 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 3 schema +struct +-- !query 3 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 4 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 4 schema +struct +-- !query 4 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 5 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +63000 50000 + + +-- !query 6 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 2012 50000 2012 + + +-- !query 7 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 7 schema +struct +-- !query 7 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 8 schema +struct +-- !query 8 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 10 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 10 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 10 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 11 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, found 'abs(earnings#x)'; + + +-- !query 12 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 From e646ae67f2e793204bc819ab2b90815214c2bbf3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 17:27:13 -0700 Subject: [PATCH 423/593] [SPARK-24168][SQL] WindowExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21225 from cloud-fan/minor3. --- .../spark/sql/execution/window/WindowExec.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f3996..626f39d9e95cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. @@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } From 0c23e254c38d4a9210939e1e1b0074278568abed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 09:27:14 +0800 Subject: [PATCH 424/593] [SPARK-24167][SQL] ParquetFilters should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't call `conf.parquetFilterPushDownDate` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21224 from cloud-fan/minor2. --- .../datasources/parquet/ParquetFileFormat.scala | 3 ++- .../datasources/parquet/ParquetFilters.scala | 15 +++++++-------- .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952de..d1f9e11ed4225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -342,6 +342,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -352,7 +353,7 @@ class ParquetFileFormat // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ccc8306866d68..310626197a763 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,14 +25,13 @@ import org.apache.parquet.io.api.Binary import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters(pushDownDate: Boolean) { private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -59,7 +58,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -85,7 +84,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -108,7 +107,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.lt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -131,7 +130,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.ltEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -154,7 +153,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -177,7 +176,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gtEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1d3476e747046..667e0b1760e3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,6 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. @@ -99,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -517,7 +519,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -525,7 +527,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -533,7 +535,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.Not( sources.And( From 7f1b6b182e3cf3cbf29399e7bfbe03fa869e0bc8 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 4 May 2018 16:02:21 +0800 Subject: [PATCH 425/593] [SPARK-24136][SS] Fix MemoryStreamDataReader.next to skip sleeping if record is available ## What changes were proposed in this pull request? Avoid unnecessary sleep (10 ms) in each invocation of MemoryStreamDataReader.next. ## How was this patch tested? Ran ContinuousSuite from IDE. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21207 from arunmahadevan/memorystream. --- .../streaming/sources/ContinuousMemoryStream.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729b..a8fca3c19a2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -183,11 +183,10 @@ class ContinuousMemoryStreamDataReader( private var current: Option[Row] = None override def next(): Boolean = { - current = None + current = getRecord while (current.isEmpty) { Thread.sleep(10) - current = endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + current = getRecord } currentOffset += 1 true @@ -199,6 +198,10 @@ class ContinuousMemoryStreamDataReader( override def getOffset: ContinuousMemoryStreamPartitionOffset = ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) From 4d5de4d303a773b1c18c350072344bd7efca9fc4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 19:20:15 +0800 Subject: [PATCH 426/593] [SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly ## What changes were proposed in this pull request? It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return wrong answer if `AccumulableParam` doesn't define equals/hashCode. This PR fixes this by using reference equality check in `LegacyAccumulatorWrapper.isZero`. ## How was this patch tested? a new test Author: Wenchen Fan Closes #21229 from cloud-fan/accumulator. --- .../org/apache/spark/util/AccumulatorV2.scala | 6 ++++-- .../spark/util/AccumulatorV2Suite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 0f84ea9752cf5..2bc84953a56eb 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..fe0a9a471a651 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } From d04806a23c1843a7f0dcc4fa236ed1b40ae113a5 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 4 May 2018 13:29:47 -0700 Subject: [PATCH 427/593] =?UTF-8?q?[SPARK-24124]=20Spark=20history=20serve?= =?UTF-8?q?r=20should=20create=20spark.history.store.=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …path and set permissions properly ## What changes were proposed in this pull request? Spark history server should create spark.history.store.path and set permissions properly. Note createdDirectories doesn't do anything if the directories are already created. This does not stomp on the permissions if the user had manually created the directory before the history server could. ## How was this patch tested? Manually tested in a 100 node cluster. Ensured directories created with proper permissions. Ensured restarted worked apps/temp directories worked as apps were read. Author: Thomas Graves Closes #21234 from tgravescs/SPARK-24124. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 56db9359e033f..bf1eeb0c1bf59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,8 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -130,8 +132,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) From af4dc50280ffcdeda208ef2dc5f8b843389732e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 4 May 2018 14:14:40 -0700 Subject: [PATCH 428/593] [SPARK-24039][SS] Do continuous processing writes with multiple compute() calls ## What changes were proposed in this pull request? Do continuous processing writes with multiple compute() calls. The current strategy (before this PR) is hacky; we just call next() on an iterator which has already returned hasNext = false, knowing that all the nodes we whitelist handle this properly. This will have to be changed before we can support more complex query plans. (In particular, I have a WIP https://github.com/jose-torres/spark/pull/13 which should be able to support aggregates in a single partition with minimal additional work.) Most of the changes here are just refactoring to accommodate the new model. The behavioral changes are: * The writer now calls prev.compute(split, context) once per epoch within the epoch loop. * ContinuousDataSourceRDD now spawns a ContinuousQueuedDataReader which is shared across multiple calls to compute() for the same partition. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #21200 from jose-torres/noAggr. --- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../continuous/ContinuousDataSourceRDD.scala | 114 +++++++++ .../ContinuousDataSourceRDDIter.scala | 222 ------------------ .../ContinuousQueuedDataReader.scala | 211 +++++++++++++++++ .../continuous/ContinuousWriteRDD.scala | 90 +++++++ .../WriteToContinuousDataSourceExec.scala | 57 +---- .../ContinuousQueuedDataReaderSuite.scala | 167 +++++++++++++ 7 files changed, 592 insertions(+), 275 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 41bdda47c8c3e..77cb707340b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -96,7 +96,11 @@ case class DataSourceV2ScanExec( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + readerFactories) .asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 0000000000000..0a3b9dcccb6c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.{NextIterator, ThreadUtils} + +class ContinuousDataSourceRDDPartition( + val index: Int, + val readerFactory: DataReaderFactory[UnsafeRow]) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. + */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readerFactories.zipWithIndex.map { + case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + }.toArray + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. + */ + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + if (partition.queueReader == null) { + partition.queueReader = + new ContinuousQueuedDataReader( + partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index 06754f01657d3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.util.ThreadUtils - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - // If attempt number isn't 0, this is a task retry, which we don't support. - if (context.attemptNumber() != 0) { - throw new ContinuousTaskRetryException() - } - - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() - - val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. - val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--$coordinatorId--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } finally { - reader.close() - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 0000000000000..01a999f6505fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. + */ +class ContinuousQueuedDataReader( + factory: DataReaderFactory[UnsafeRow], + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + private val reader = factory.createDataReader() + + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentEpoch: Long = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. + */ + def next(): UnsafeRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + context.partitionId(), currentEpoch, currentOffset)) + currentEpoch += 1 + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[DataReader]]. + */ + class DataReaderThread extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // field represents the epoch wrt the data being processed. The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. + for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 0000000000000..91f1576581511 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. + */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch committed.") + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index ba88ae1af469a..e0af3a2f1b85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -46,24 +46,19 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } - val rdd = query.execute() + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) logInfo(s"Start processing data source writer: $writer. " + - s"The input RDD has ${rdd.getNumPartitions} partitions.") - // Let the epoch coordinator know how many partitions the write RDD has. + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) try { // Force the RDD to run so continuous processing starts; no data is actually being collected // to the driver, as ContinuousWriteRDD outputs nothing. - sparkContext.runJob( - rdd, - (context: TaskContext, iter: Iterator[InternalRow]) => - WriteToContinuousDataSourceExec.run(writerFactory, context, iter), - rdd.partitions.indices) + rdd.collect() } catch { case _: InterruptedException => // Interruption is how continuous queries are ended, so accept and ignore the exception. @@ -80,45 +75,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla sparkContext.emptyRDD } } - -object WriteToContinuousDataSourceExec extends Logging { - def run( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): Unit = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 0000000000000..e755625d09e0f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamWriter], + mock[ContinuousReader], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. + */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val factory = new DataReaderFactory[UnsafeRow] { + override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + } + val reader = new ContinuousQueuedDataReader( + factory, + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} From 47b5b68528c154d32b3f40f388918836d29462b8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 May 2018 16:35:24 -0700 Subject: [PATCH 429/593] [SPARK-24157][SS] Enabled no-data batches in MicroBatchExecution for streaming aggregation and deduplication. ## What changes were proposed in this pull request? This PR enables the MicroBatchExecution to run no-data batches if some SparkPlan requires running another batch to output results based on updated watermark / processing time. In this PR, I have enabled streaming aggregations and streaming deduplicates to automatically run addition batch even if new data is available. See https://issues.apache.org/jira/browse/SPARK-24156 for more context. Major changes/refactoring done in this PR. - Refactoring MicroBatchExecution - A major point of confusion in MicroBatchExecution control flow was always (at least to me) was that `populateStartOffsets` internally called `constructNextBatch` which was not obvious from just the name "populateStartOffsets" and made the control flow from the main trigger execution loop very confusing (main loop in `runActivatedStream` called `constructNextBatch` but only if `populateStartOffsets` hadn't already called it). Instead, the refactoring makes it cleaner. - `populateStartOffsets` only the updates `availableOffsets` and `committedOffsets`. Does not call `constructNextBatch`. - Main loop in `runActivatedStream` calls `constructNextBatch` which returns true or false reflecting whether the next batch is ready for executing. This method is now idempotent; if a batch has already been constructed, then it will always return true until the batch has been executed. - If next batch is ready then we call `runBatch` or sleep. - That's it. - Refactoring watermark management logic - This has been refactored out from `MicroBatchExecution` in a separate class to simplify `MicroBatchExecution`. - New method `shouldRunAnotherBatch` in `IncrementalExecution` - This returns true if there is any stateful operation in the last execution plan that requires another batch for state cleanup, etc. This is used to decide whether to construct a batch or not in `constructNextBatch`. - Changes to stream testing framework - Many tests used CheckLastBatch to validate answers. This assumed that there will be no more batches after the last set of input has been processed, so the last batch is the one that has output corresponding to the last input. This is not true anymore. To account for that, I made two changes. - `CheckNewAnswer` is a new test action that verifies the new rows generated since the last time the answer was checked by `CheckAnswer`, `CheckNewAnswer` or `CheckLastBatch`. This is agnostic to how many batches occurred between the last check and now. To do make this easier, I added a common trait between MemorySink and MemorySinkV2 to abstract out some common methods. - `assertNumStateRows` has been updated in the same way to be agnostic to batches while checking what the total rows and how many state rows were updated (sums up updates since the last check). ## How was this patch tested? - Changes made to existing tests - Tests have been changed in one of the following patterns. - Tests where the last input was given again to force another batch to be executed and state cleaned up / output generated, they were simplified by removing the extra input. - Tests using aggregation+watermark where CheckLastBatch were replaced with CheckNewAnswer to make them batch agnostic. - New tests added to check whether the flag works for streaming aggregation and deduplication Author: Tathagata Das Closes #21220 from tdas/SPARK-24157. --- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../streaming/IncrementalExecution.scala | 10 + .../streaming/MicroBatchExecution.scala | 231 ++++++++---------- .../streaming/WatermarkTracker.scala | 73 ++++++ .../sql/execution/streaming/memory.scala | 17 +- .../streaming/sources/memoryV2.scala | 8 +- .../streaming/statefulOperators.scala | 16 ++ .../sources/ForeachWriterSuite.scala | 8 +- .../streaming/EventTimeWatermarkSuite.scala | 112 ++++----- .../sql/streaming/FileStreamSinkSuite.scala | 7 +- .../sql/streaming/StateStoreMetricsTest.scala | 52 +++- .../spark/sql/streaming/StreamTest.scala | 56 ++++- ...cala => StreamingDeduplicationSuite.scala} | 94 +++---- .../sql/streaming/StreamingJoinSuite.scala | 18 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- 15 files changed, 450 insertions(+), 265 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{DeduplicateSuite.scala => StreamingDeduplicationSuite.scala} (80%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3942240c442b2..895e150756567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -919,6 +919,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -1313,6 +1321,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a83c884d55bd..c480b96626f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -143,4 +143,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6e231970f4a22..6709e7052f005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,6 +61,8 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private val watermarkTracker = new WatermarkTracker() + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -128,40 +130,55 @@ class MicroBatchExecution( * Repeatedly attempts to run batches as data arrives. */ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (currentBatchIsRunnable) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // If the current batch has been executed, then increment the batch id, else there was + // no data to execute the batch + if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -211,6 +228,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -235,7 +253,6 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets - constructNextBatch() } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -243,19 +260,18 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -266,93 +282,63 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. + * + * Here is the high-level logic on how this constructs the next batch. + * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - case s: MicroBatchReader => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) - }.toMap - availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false + private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { + // If new data is already available that means this method has already been called before + // and it must have already committed the offset range of next batch to the offset log. + // Hence do nothing, just return true. + if (isNewDataAvailable) return true + + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } - } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -373,7 +359,7 @@ class MicroBatchExecution( reader.commit(reader.deserializeOffset(off.json)) } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -384,15 +370,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -400,6 +383,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -513,17 +498,17 @@ class MicroBatchExecution( } } - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. + withProgressLocked { + commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() } + watermarkTracker.updateWatermark(lastExecution.executedPlan) + logDebug(s"Completed batch ${currentBatchId}") } /** Execute a function while locking the stream from making an progress */ - private[sql] def withProgressLocked(f: => Unit): Unit = { + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 0000000000000..80865669558dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan + +class WatermarkTracker extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var watermarkMs: Long = 0 + private var updated = false + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + watermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 + if (newWatermarkMs > watermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + watermarkMs = newWatermarkMs + updated = true + } else { + logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") + updated = false + } + } + + def currentWatermark: Long = synchronized { watermarkMs } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce7..22258274c70c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -222,11 +222,20 @@ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) } } +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -236,7 +245,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -245,6 +254,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index d871d37ad37c1..0d6c239274dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { override def createStreamWriter( queryId: String, schema: StructType, @@ -67,6 +67,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c9354ac0ec78a..1691a6320a526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -126,6 +126,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -388,6 +394,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. */ @@ -454,6 +466,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 03bf71b3f4b78..e60c339bc9cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -211,14 +211,12 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. val allEvents = ForeachWriterSuite.allEvents() - assert(allEvents.size === 3) + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( ForeachWriterSuite.Open(partition = 0, version = 0), @@ -230,6 +228,10 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd ), Seq( ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 3), ForeachWriterSuite.Process(value = 3), ForeachWriterSuite.Close(None) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce07379..7e8fde1ff8e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -137,20 +138,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.get("watermark") === formatTimestamp(5)) }, AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), CheckAnswer((10, 3)), assertEventStats { e => assert(e.get("max") === formatTimestamp(25)) assert(e.get("min") === formatTimestamp(25)) assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(5)) } ) } @@ -167,15 +160,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +183,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +241,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +380,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. CheckAnswer((10, 1)) ) } @@ -501,8 +458,35 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index cf41d7e0e4fe1..ed53def556cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -279,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca8..e45f9d3e2e97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,58 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af0268fa47871..9d139a927bea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -192,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( @@ -202,6 +203,23 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + + private def operatorName = "CheckNewAnswer" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + } + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -435,13 +453,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { currentStream.awaitOffset(sourceIndex, offset) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } @@ -463,14 +492,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) - } - try if (lastOnly) latestBatchData else allData catch { + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } def executeAction(action: StreamAction): Unit = { @@ -704,6 +740,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } catch { case e: Throwable => failTest(e.toString) } + + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } } pos += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0088b64d6195e..42ffd472eb843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -97,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -141,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in [15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -260,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -279,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 11bdd13942dcb..da8f9608c1e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -192,7 +192,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 5, 11)), AddData(rightInput, (1, 10)), CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), @@ -276,14 +276,14 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), CheckLastBatch(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + assertNumStateRows(total = 12, updated = 5), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) @@ -573,7 +573,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 3, updated = 1) @@ -591,7 +591,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), assertNumStateRows(total = 3, updated = 1) @@ -630,7 +630,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 1, 5, 10)), AddData(rightInput, (1, 11)), CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), @@ -668,7 +668,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added AddData(rightInput, 20), CheckLastBatch( Row(20, 30, 40, 60)), @@ -678,7 +678,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + assertNumStateRows(total = 6, updated = 6), // all inputs added since last check AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), @@ -687,7 +687,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(), MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added AddData(rightInput, 1000), CheckLastBatch( Row(1000, 1010, 2000, 3000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 390d67d1feb27..0cb2375e0a49a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -334,7 +334,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) + assert(progress.sources(0).startOffset === "0") assert(progress.sources(0).endOffset !== null) assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms From dd4b1b9c7ccad3363a6a21524aed047fcd282f68 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 6 May 2018 10:25:01 +0800 Subject: [PATCH 430/593] [SPARK-24185][SPARKR][SQL] add flatten function to SparkR ## What changes were proposed in this pull request? add array flatten function to SparkR ## How was this patch tested? Unit tests were added in R/pkg/tests/fulltests/test_sparkSQL.R Author: Huaxin Gao Closes #21244 from huaxingao/spark-24185. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 14 ++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 4 files changed, 25 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f36d462a83cb0..8cd00352d1956 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -258,6 +258,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ec4bd4e73c7e5..0ec99d19e21e4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,6 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -3035,6 +3036,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{flatten}: Transforms an array of arrays into a single array. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 562d3399ee9c8..4ef12d19b3575 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 8cc2db7a140f9..3a8866bf2a88a 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1502,6 +1502,12 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + # Test flattern + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] From f38ea00e83099a5ae8d3afdec2e896e43c2db612 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 May 2018 20:41:32 -0700 Subject: [PATCH 431/593] [SPARK-24017][SQL] Refactor ExternalCatalog to be an interface ## What changes were proposed in this pull request? This refactors the external catalog to be an interface. It can be easier for the future work in the catalog federation. After the refactoring, `ExternalCatalog` is much cleaner without mixing the listener event generation logic. ## How was this patch tested? The existing tests Author: gatorsmile Closes #21122 from gatorsmile/refactorExternalCatalog. --- .../catalyst/catalog/ExternalCatalog.scala | 134 ++------ .../catalog/ExternalCatalogWithListener.scala | 298 ++++++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 26 +- .../catalog/ExternalCatalogEventSuite.scala | 2 +- .../spark/sql/internal/SharedState.scala | 9 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 26 +- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../sql/hive/HiveSessionStateBuilder.scala | 7 +- .../sql/hive/execution/SaveAsHiveFile.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 11 +- .../HiveExternalSessionCatalogSuite.scala | 2 +- .../sql/hive/HiveSchemaInferenceSuite.scala | 3 +- .../sql/hive/HiveSessionStateSuite.scala | 3 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 5 +- .../spark/sql/hive/ShowCreateTableSuite.scala | 3 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../sql/hive/test/TestHiveSingleton.scala | 2 +- 20 files changed, 384 insertions(+), 168 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c1..1a145c24d78cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 0000000000000..2f009be5816fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd52..741dc46b07382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a075..2fcaeca34db3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e3..5b6160e2b408f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index cbd75ad12d430..8980bcf885589 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,7 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 28c340a176d91..011a3ba553cb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -158,13 +158,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +480,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +489,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +624,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,7 +656,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { @@ -1208,7 +1208,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1221,12 +1221,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1235,7 +1235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index e5aff3b99d0b9..94ddeae1bf547 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalogBuilder: () => HiveExternalCatalog, + externalCatalogBuilder: () => ExternalCatalog, globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 40b9bb51ca9a0..2882672f327c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,14 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - new HiveSessionResourceLoader(session, () => externalCatalog.client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 6a7b25b36d9a5..e0f7375387d24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -122,7 +122,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 965aea2b61456..ee3f99ab7e9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -208,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0eac..fd5f47e428239 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d7..51a48a20daaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index ecc09cdcdbeaf..a3579862c9e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -44,7 +44,8 @@ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { val conf = sparkSession.sparkContext.hadoopConfiguration val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) sparkSession.cloneSession() - sparkSession.sharedState.externalCatalog.client.newSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) assert(oldValue == newValue, "cloneSession and then newSession should not affect the Derby directory") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 079fe45860544..aa5b531992613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -780,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474e..473bbced41b31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -289,7 +289,8 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6176273c88db1..dc96ec416afd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -134,8 +134,8 @@ class VersionsSuite extends SparkFunSuite with Logging { client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index daac6af9b557f..0341c3b378918 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1355,7 +1355,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1393,7 +1394,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 704a410b6a37b..828c18a770c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2099,7 +2099,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d3fff37c3424d..d50bf0b8fd603 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -30,7 +30,7 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { From a634d66ce767bd5e1d8553d1a2c32e2b1a80f642 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 7 May 2018 13:00:18 +0800 Subject: [PATCH 432/593] [SPARK-24126][PYSPARK] Use build-specific temp directory for pyspark tests. This avoids polluting and leaving garbage behind in /tmp, and allows the usual build tools to clean up any leftover files. Author: Marcelo Vanzin Closes #21198 from vanzin/SPARK-24126. --- python/pyspark/sql/tests.py | 4 ++-- python/pyspark/streaming/tests.py | 6 ++++-- python/pyspark/tests.py | 33 ++++++++++++++++++++----------- python/run-tests.py | 29 +++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc6acfdb07d99..16aa9378ad8ee 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3092,8 +3092,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d77f1baa1f344..e4a428a0b27e7 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -1549,7 +1549,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8392d7f29af53..7b8ce2c6b799f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1951,7 +1951,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2017,7 +2022,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2033,7 +2038,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2051,7 +2056,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2070,7 +2075,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2087,8 +2092,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2103,9 +2110,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2124,7 +2133,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2144,7 +2153,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() diff --git a/python/run-tests.py b/python/run-tests.py index f408fc5082b3d..4c90926cfa350 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,11 +22,13 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: @@ -68,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -77,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. + spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -84,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -238,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. + target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -245,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() From 889f6cc10cbd7781df04f468674a61f0ac5a870b Mon Sep 17 00:00:00 2001 From: jinxing Date: Mon, 7 May 2018 14:16:27 +0800 Subject: [PATCH 433/593] [SPARK-24143] filter empty blocks when convert mapstatus to (blockId, size) pair MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code(`MapOutputTracker.convertMapStatuses`), mapstatus are converted to (blockId, size) pair for all blocks – no matter the block is empty or not, which result in OOM when there are lots of consecutive empty blocks, especially when adaptive execution is enabled. (blockId, size) pair is only used in `ShuffleBlockFetcherIterator` to control shuffle-read and only non-empty block request is sent. Can we just filter out the empty blocks in MapOutputTracker.convertMapStatuses and save memory? ## How was this patch tested? not added yet. Author: jinxing Closes #21212 from jinxing64/SPARK-24143. --- .../org/apache/spark/MapOutputTracker.scala | 31 +++++++++------- .../storage/ShuffleBlockFetcherIterator.scala | 35 +++++++++++-------- .../apache/spark/MapOutputTrackerSuite.scala | 31 +++++++++++++++- .../BlockStoreShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 19 +++++----- 5 files changed, 80 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b36..73646051f264c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -632,9 +632,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +643,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +670,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +843,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. @@ -857,22 +860,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dd9df74689a13..6971efd2504c2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. + } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d9..21f481d477242 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0))) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..2d8a83c6fabed 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 692ae3bf597e0..cefebfa51b8b9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -99,7 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +176,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -310,7 +310,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -378,7 +378,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ) + ).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -437,7 +437,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -495,7 +495,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -513,14 +514,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. From 7564a9a70695dac2f0b5f51493d37cbc93691663 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 15:22:23 +0900 Subject: [PATCH 434/593] [SPARK-23921][SQL] Add array_sort function ## What changes were proposed in this pull request? The PR adds the SQL function `array_sort`. The behavior of the function is based on Presto's one. The function sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21021 from kiszk/SPARK-23921. --- python/pyspark/sql/functions.py | 26 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 240 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 34 ++- .../org/apache/spark/sql/functions.scala | 12 + .../spark/sql/DataFrameFunctionsSuite.scala | 34 ++- 6 files changed, 292 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad4bd6f5089e9..bd55b5f73b4d0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2183,20 +2183,38 @@ def array_max(col): def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 51bb6b0abe408..01776b85e6f53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -403,6 +403,7 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6d63a531e3b74..23c09bc3b49d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,6 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ @@ -119,47 +120,16 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Common base class for [[SortArray]] and [[ArraySort]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } + protected def nullOrder: NullOrder @transient private lazy val lt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -170,9 +140,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - -1 + nullOrder } else if (o2 == null) { - 1 + -nullOrder } else { ordering.compare(o1, o2) } @@ -182,7 +152,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -193,9 +163,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - 1 + -nullOrder } else if (o2 == null) { - -1 + nullOrder } else { -ordering.compare(o1, o2) } @@ -203,18 +173,200 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } } - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType + def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + java.util.Arrays.sort(data, if (ascending) lt else gt) } new GenericArrayData(data.asInstanceOf[Array[Any]]) } + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } + s""" + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); + |} + """.stripMargin + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 + } +} + +/** + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { + + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + override def prettyName: String = "sort_array" } + +/** + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { + + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + /** * Returns a reversed string or an array with reverse order of elements. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7048d93fd5649..749374f1a14a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -61,28 +61,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2e22fa355514..10b6dcc0608c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3093,6 +3093,15 @@ object functions { ElementAt(column.expr, Literal(value)) } + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** * Creates a new row for each element in the given array or map column. * @@ -3332,6 +3341,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3341,6 +3351,8 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a5163accb1bb3..ae21cbc802d0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,7 +276,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ -286,28 +286,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,6 +324,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } test("array size function") { From d2aa859b4faeda03e32a7574dd0c5b4ed367fae4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 7 May 2018 14:34:03 +0800 Subject: [PATCH 435/593] [SPARK-24160] ShuffleBlockFetcherIterator should fail if it receives zero-size blocks ## What changes were proposed in this pull request? This patch modifies `ShuffleBlockFetcherIterator` so that the receipt of zero-size blocks is treated as an error. This is done as a preventative measure to guard against a potential source of data loss bugs. In the shuffle layer, we guarantee that zero-size blocks will never be requested (a block containing zero records is always 0 bytes in size and is marked as empty such that it will never be legitimately requested by executors). However, the existing code does not fully take advantage of this invariant in the shuffle-read path: the existing code did not explicitly check whether blocks are non-zero-size. Additionally, our decompression and deserialization streams treat zero-size inputs as empty streams rather than errors (EOF might actually be treated as "end-of-stream" in certain layers (longstanding behavior dating to earliest versions of Spark) and decompressors like Snappy may be tolerant to zero-size inputs). As a result, if some other bug causes legitimate buffers to be replaced with zero-sized buffers (due to corruption on either the send or receive sides) then this would translate into silent data loss rather than an explicit fail-fast error. This patch addresses this problem by adding a `buf.size != 0` check. See code comments for pointers to tests which guarantee the invariants relied on here. ## How was this patch tested? Existing tests (which required modifications, since some were creating empty buffers in mocks). I also added a test to make sure we fail on zero-size blocks. To test that the zero-size blocks are indeed a potential corruption source, I manually ran a workload in `spark-shell` with a modified build which replaces all buffers with zero-size buffers in the receive path. Author: Josh Rosen Closes #21219 from JoshRosen/SPARK-24160. --- .../storage/ShuffleBlockFetcherIterator.scala | 19 +++++ .../ShuffleBlockFetcherIteratorSuite.scala | 71 +++++++++++++------ 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6971efd2504c2..b31862323a895 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -414,6 +414,25 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cefebfa51b8b9..8e9374b768adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are not checked for corruption") { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - doReturn(10000L).when(corruptBuffer).size() + val corruptBuffer = mockCorruptBuffer(10000L) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -428,9 +424,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -527,4 +523,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // shuffle block to disk. assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } From c5981976f1d514a3ad8a684b9a21cebe38b786fa Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 7 May 2018 14:45:14 +0800 Subject: [PATCH 436/593] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `onTaskStart` where stage ID is provided. In order to make sure cancelStage called for all stages `waitUntilEmpty` is called on `ListenerBus`. In [PR20888](https://github.com/apache/spark/pull/20888) this tried to get solved by: * Stopping the executor thread with `wait` * Wait for all `cancelStage` called * Kill the executor thread by setting `SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL` but the thread killing left the shared `SparkContext` sometimes in a state where further jobs can't be submitted. As a result DataFrameRangeSuite.test("Cancelling stage in a query with Range.") test passed properly but the next test inside the suite was hanging. ## How was this patch tested? Existing unit test executed 10k times. Author: Gabor Somogyi Closes #21214 from gaborgsomogyi/SPARK-23775_1. --- .../spark/sql/DataFrameRangeSuite.scala | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..b0b46640ff317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -23,8 +23,8 @@ import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -153,23 +153,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) - } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) } } sparkContext.addSparkListener(listener) for (codegen <- Seq(true, false)) { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } ex.getCause() match { case null => @@ -180,6 +174,8 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } } + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) eventually(timeout(20.seconds)) { assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } @@ -204,7 +200,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From f06528015d5856d6dc5cce00309bc2ae985e080f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 15:42:10 +0800 Subject: [PATCH 437/593] [SPARK-24160][FOLLOWUP] Fix compilation failure ## What changes were proposed in this pull request? SPARK-24160 is causing a compilation failure (after SPARK-24143 was merged). This fixes the issue. ## How was this patch tested? building successfully Author: Marco Gaido Closes #21256 from mgaido91/SPARK-24160_FOLLOWUP. --- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 8e9374b768adc..a2997dbd1b1ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -546,7 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext, transfer, blockManager, - blocksByAddress, + blocksByAddress.toIterator, (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, From e35ad3caddeaa4b0d4c8524dcfb9e9f56dc7fe3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 16:57:37 +0900 Subject: [PATCH 438/593] [SPARK-23930][SQL] Add slice function ## What changes were proposed in this pull request? The PR add the `slice` function. The behavior of the function is based on Presto's one. The function slices an array according to the requested start index and length. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21040 from mgaido91/SPARK-23930. --- python/pyspark/sql/functions.py | 13 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 163 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 28 +++ .../expressions/ExpressionEvalHelper.scala | 6 + .../expressions/ObjectExpressionsSuite.scala | 1 - .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++ 9 files changed, 233 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bd55b5f73b4d0..ac3c79766702c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,19 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + @ignore_unicode_prefix @since(2.4) def array_join(col, delimiter, null_replacement=None): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 01776b85e6f53..87b0911e150c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Slice]("slice"), expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..4dda525294259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -730,6 +731,39 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. + * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 23c09bc3b49d7..12b9ab2b272ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -530,6 +529,129 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + /** * Creates a String containing all the elements of the input array separated by the delimiter. */ @@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends Expression { } private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" @@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends Expression { | public ArrayData concat($javaType[] args) { | ${nullArgumentProtection()} | $numElemCode - | $unsafeArraySizeInBytes - | byte[] $arrayName = new byte[(int)$arraySizeName]; - | UnsafeArrayData $arrayData = new UnsafeArrayData(); - | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { | for (int z = 0; z < args[y].numElements(); z++) { @@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + - | " bytes for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" |$numElemCode - |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[(int)$arraySizeName]; - |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); - |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} |int $counter = 0; |for (int k = 0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 749374f1a14a1..a2851d071c7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + test("ArrayJoin") { def testArrays( arrays: Seq[Expression], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..a22e9d4655e8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( expression: => Expression, inputRow: InternalRow, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 730b36c32333c..77ca640f2e0bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new java.util.LinkedList[Int]), Map("nonexisting" -> Literal(1))) checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), """A method named "nonexisting" is not declared in any enclosing class """ + "nor any supertype") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b6dcc0608c2..8f9e4ae18b3f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + /** * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with * `nullReplacement`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ae21cbc802d0a..ecce06f4c0755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + test("array_join function") { val df = Seq( (Seq[String]("a", "b"), ","), From 4e861db5f149e10fd8dfe6b3c1484821a590b1e8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 7 May 2018 11:21:22 +0200 Subject: [PATCH 439/593] [SPARK-16406][SQL] Improve performance of LogicalPlan.resolve ## What changes were proposed in this pull request? `LogicalPlan.resolve(...)` uses linear searches to find an attribute matching a name. This is fine in normal cases, but gets problematic when you try to resolve a large number of columns on a plan with a large number of attributes. This PR adds an indexing structure to `resolve(...)` in order to find potential matches quicker. This PR improves the reference resolution time for the following code by 4x (11.8s -> 2.4s): ``` scala val n = 4000 val values = (1 to n).map(_.toString).mkString(", ") val columns = (1 to n).map("column" + _).mkString(", ") val query = s""" |SELECT $columns |FROM VALUES ($values) T($columns) |WHERE 1=2 AND 1 IN ($columns) |GROUP BY $columns |ORDER BY $columns |""".stripMargin spark.time(sql(query)) ``` ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #14083 from hvanhovell/SPARK-16406. --- .../sql/catalyst/expressions/package.scala | 86 ++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 108 ++---------------- 2 files changed, 93 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af7..8a06daa37132d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 42034403d6d03..e487693927ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,105 +118,7 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. - case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** From d83e9637246b05eea202add07a168688f6c0481b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 7 May 2018 17:54:39 +0200 Subject: [PATCH 440/593] [SPARK-24043][SQL] Interpreted Predicate should initialize nondeterministic expressions ## What changes were proposed in this pull request? When creating an InterpretedPredicate instance, initialize any Nondeterministic expressions in the expression tree to avoid java.lang.IllegalArgumentException on later call to eval(). ## How was this patch tested? - sbt SQL tests - python SQL tests - new unit test Author: Bruce Robbins Closes #21144 from bersprockets/interpretedpredicate. --- .../spark/sql/catalyst/expressions/predicates.scala | 8 ++++++++ .../spark/sql/catalyst/expressions/PredicateSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..f8c6dc4e6adc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -36,6 +36,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -449,4 +449,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } From 56a52e0a58fc82ea69e47d0d8c4f905565be7c8b Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 7 May 2018 14:47:58 -0700 Subject: [PATCH 441/593] [SPARK-15750][MLLIB][PYSPARK] Constructing FPGrowth fails when no numPartitions specified in pyspark ## What changes were proposed in this pull request? Change FPGrowth from private to private[spark]. If no numPartitions is specified, then default value -1 is used. But -1 is only valid in the construction function of FPGrowth, but not in setNumPartitions. So I make this change and use the constructor directly rather than using set method. ## How was this patch tested? Unit test is added Author: Jeff Zhang Closes #13493 from zjffdu/SPARK-15750. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 5 +---- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae59..db3f074ecfbac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d16..4f2b7e6f0764e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 14d788b0bef60..4c2ce137e331c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,6 +1763,17 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: From 1c9c5de951ed86290bcd7d8edaab952b8cacd290 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 May 2018 14:52:14 -0700 Subject: [PATCH 442/593] [SPARK-23291][SPARK-23291][R][FOLLOWUP] Update SparkR migration note for ## What changes were proposed in this pull request? This PR fixes the migration note for SPARK-23291 since it's going to backport to 2.3.1. See the discussion in https://issues.apache.org/jira/browse/SPARK-23291 ## How was this patch tested? N/A Author: hyukjinkwon Closes #21249 from HyukjinKwon/SPARK-23291. --- docs/sparkr.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index 7fabab5d38f16..4faad2c4c1824 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -664,6 +664,6 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. -## Upgrading to Spark 2.4.0 +## Upgrading to SparkR 2.3.1 and above - - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. From f48bd6bdc5aefd9ec43e2d0ee648d17add7ef554 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:55:41 -0700 Subject: [PATCH 443/593] [SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning ## What changes were proposed in this pull request? ML test for StructuredStreaming: spark.ml.tuning ## How was this patch tested? N/A Author: WeichenXu Closes #20261 from WeichenXu123/ml_stream_tuning_test. --- .../spark/ml/tuning/CrossValidatorSuite.scala | 15 +++++++++++---- .../ml/tuning/TrainValidationSplitSuite.scala | 15 +++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade2627090..e6ee7220d2279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c831..cd76acf9c67bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { From 76ecd095024a658bf68e5db658e4416565b30c17 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:57:14 -0700 Subject: [PATCH 444/593] [SPARK-20114][ML] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? PrefixSpan API for spark.ml. New implementation instead of #20810 ## How was this patch tested? TestSuite added. Author: WeichenXu Closes #20973 from WeichenXu123/prefixSpan2. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 96 +++++++++++++ .../apache/spark/mllib/fpm/PrefixSpan.scala | 3 +- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 136 ++++++++++++++++++ 3 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 0000000000000..02168fee16dbf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). + * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +object PrefixSpan { + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column which is + * {{{Seq[Seq[_]]}}} type + * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column + * are ignored + * @param minSupport the minimal support level of the sequential pattern, any pattern that + * appears more than (minSupport * size-of-the-dataset) times will be output + * (recommended value: `0.1`). + * @param maxPatternLength the maximal length of the sequential pattern + * (recommended value: `10`). + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the + * internal storage format) allowed in a projected database before + * local processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run + * (recommended value: `32000000`). + * @return A `DataFrame` that contains columns of sequence and corresponding frequency. + * The schema of it will be: + * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns( + dataset: Dataset[_], + sequenceCol: String, + minSupport: Double, + maxPatternLength: Int, + maxLocalProjDBSize: Long): DataFrame = { + + val inputType = dataset.schema(sequenceCol).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + + val data = dataset.select(sequenceCol) + val sequences = data.where(col(sequenceCol).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(maxLocalProjDBSize) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2c..7aed2f3bd8a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000000..9e538696cbcf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", + minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === actualSet) + } +} + From 0d63eb8888d17df747fb41d7ba254718bb7af3ae Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 7 May 2018 20:08:41 -0700 Subject: [PATCH 445/593] [SPARK-23975][ML] Add support of array input for all clustering methods ## What changes were proposed in this pull request? Add support for all of the clustering methods ## How was this patch tested? unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21195 from ludatabricks/SPARK-23975-1. --- .../spark/ml/clustering/BisectingKMeans.scala | 21 ++++---- .../spark/ml/clustering/GaussianMixture.scala | 12 +++-- .../apache/spark/ml/clustering/KMeans.scala | 31 +++--------- .../org/apache/spark/ml/clustering/LDA.scala | 9 ++-- .../apache/spark/ml/util/DatasetUtils.scala | 13 ++++- .../apache/spark/ml/util/SchemaUtils.scala | 16 ++++++- .../ml/clustering/BisectingKMeansSuite.scala | 21 +++++++- .../ml/clustering/GaussianMixtureSuite.scala | 21 +++++++- .../spark/ml/clustering/KMeansSuite.scala | 48 ++++++------------- .../apache/spark/ml/clustering/LDASuite.scala | 20 +++++++- .../apache/spark/ml/util/MLTestingUtils.scala | 23 ++++++++- 11 files changed, 147 insertions(+), 88 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index addc12ac52ec1..438e53ba6197c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,17 +22,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -75,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -113,7 +111,8 @@ class BisectingKMeansModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -132,9 +131,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -260,9 +259,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index b5804900c0358..88d618c3a03a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[Vector] = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index de61c9c089a36..97f246fbfd859 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, PipelineStage} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,24 +86,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) - /** - * Validates the input schema. - * @param schema input schema - */ - private[clustering] def validateSchema(schema: StructType): Unit = { - val typeCandidates = List( new VectorUDT, - new ArrayType(DoubleType, false), - new ArrayType(FloatType, false)) - - SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) - } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - validateSchema(schema) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -160,12 +149,8 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - validateSchema(dataset.schema) - - val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } @@ -351,11 +336,7 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select( - DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 47077230fac0a..afe599cd167cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -461,7 +461,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") @@ -938,7 +939,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala index 52619cb65489a..6af4b3ebc2cc2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.linalg.{Vectors, VectorUDT} -import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} @@ -60,4 +62,11 @@ private[spark] object DatasetUtils { throw new IllegalArgumentException(s"$other column cannot be cast to Vector") } } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620de..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -101,4 +102,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 02880f96ae6d9..f3ff2afcad2cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,13 +17,16 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -182,6 +185,22 @@ class BisectingKMeansSuite model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e4183..d0d461a42711a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap @@ -24,8 +26,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - +import org.apache.spark.sql.{DataFrame, Dataset, Row} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 5445ebe5c95eb..680a7c2034083 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} @@ -25,13 +26,11 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, - KMeansModel => MLlibKMeansModel} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -202,38 +201,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("KMean with Array input") { - val featuresColNameD = "array_double_features" - val featuresColNameF = "array_float_features" - - val doubleUDF = udf { (features: Vector) => - val featureArray = Array.fill[Double](features.size)(0.0) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray - } - val floatUDF = udf { (features: Vector) => - val featureArray = Array.fill[Float](features.size)(0.0f) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) } - val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) - .drop("features") - val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) - .drop("features") - assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) - assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) - - val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) - val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) - val modelD = kmeansD.fit(newdatasetD) - val modelF = kmeansF.fit(newdatasetF) - val transformedD = modelD.transform(newdatasetD) - val transformedF = modelF.transform(newdatasetF) - - val predictDifference = transformedD.select("prediction") - .except(transformedF.select("prediction")) - assert(predictDifference.count() == 0) - assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76bd..8d728f063dd8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite @@ -26,7 +28,6 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -323,4 +324,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index c328d81b4bc3a..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -247,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } From cd12c5c3ecf28f7b04f566c2057f9b65eb456b7d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 8 May 2018 12:21:33 +0800 Subject: [PATCH 446/593] [SPARK-24128][SQL] Mention configuration option in implicit CROSS JOIN error ## What changes were proposed in this pull request? Mention `spark.sql.crossJoin.enabled` in error message when an implicit `CROSS JOIN` is detected. ## How was this patch tested? `CartesianProductSuite` and `JoinSuite`. Author: Henry Robinson Closes #21201 from henryr/spark-24128. --- R/pkg/tests/fulltests/test_sparkSQL.R | 4 ++-- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 ++++-- .../catalyst/optimizer/CheckCartesianProductsSuite.scala | 2 +- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 3a8866bf2a88a..43725e0ebd3bf 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2210,8 +2210,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 45f13956a0a85..bfa61116a6658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1182,12 +1182,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. - |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e8..788fedb3c8e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63ab..8fa747465cb1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 +611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) From 05eb19b6e09065265358eec2db2ff3b42806dfc9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 May 2018 14:32:04 +0800 Subject: [PATCH 447/593] [SPARK-24188][CORE] Restore "/version" API endpoint. It was missing the jax-rs annotation. Author: Marcelo Vanzin Closes #21245 from vanzin/SPARK-24188. Change-Id: Ib338e34b363d7c729cc92202df020dc51033b719 --- .../org/apache/spark/status/api/v1/ApiRootResource.scala | 1 + .../org/apache/spark/deploy/history/HistoryServerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 7127397f6205c..d121068718b8a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87f12f303cd5e..a871b1c717837 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -296,6 +296,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) From e17567ca78dbb416039c17da212957c8955bfa65 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 8 May 2018 11:34:27 +0200 Subject: [PATCH 448/593] [SPARK-24076][SQL] Use different seed in HashAggregate to avoid hash conflict ## What changes were proposed in this pull request? HashAggregate uses the same hash algorithm and seed as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n. Considering below example: ``` SET spark.sql.shuffle.partitions=8192; INSERT OVERWRITE TABLE target_xxx SELECT item_id, auct_end_dt FROM from source_xxx GROUP BY item_id, auct_end_dt; ``` In the shuffle stage, if user sets the shuffle.partition = 8192, all tuples in the same partition will meet the following relationship: ``` hash(tuple x) = hash(tuple y) + n * 8192 ``` Then in the next HashAggregate stage, all tuples from the same partition need be put into a 16K BytesToBytesMap (unsafeRowAggBuffer). Here, the HashAggregate uses the same hash algorithm on the same expression as shuffle, and uses the same seed, and 16K = 8192 * 2, so actually, all tuples in the same parititon will only be hashed to 2 different places in the BytesToBytesMap. It is bad hash conflict. With BytesToBytesMap growing, the conflict will always exist. Before change: hash_conflict After change: no_hash_conflict ## How was this patch tested? Unit tests and production cases. Author: yucai Closes #21149 from yucai/SPARK-24076. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..6a8ec4f722aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -755,7 +755,10 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, From b54bbe57b33b00063596cd9588fa2461745ed571 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 May 2018 21:22:54 +0800 Subject: [PATCH 449/593] [SPARK-24131][PYSPARK][FOLLOWUP] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? More close to Scala API behavior when can't parse input by throwing exception. Add tests. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21211 from viirya/SPARK-24131-followup. --- python/pyspark/tests.py | 4 ++++ python/pyspark/util.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7b8ce2c6b799f..498d6b57e4353 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2312,6 +2312,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 04df835bf6717..59cc2a6329350 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -62,24 +62,31 @@ def _get_argspec(f): return argspec -def majorMinorVersion(version): +class VersionUtils(object): """ - Get major and minor version numbers for given Spark version string. - - >>> version = "2.4.0" - >>> majorMinorVersion(version) - (2, 4) + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). - >>> version = "abc" - >>> majorMinorVersion(version) is None - True + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) - """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - return None - else: - return (int(m.group(1)), int(m.group(2))) + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") if __name__ == "__main__": From 2f6fe7d679a878ffd103cac6f06081c5b3888744 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 May 2018 21:24:35 +0800 Subject: [PATCH 450/593] [SPARK-23094][SPARK-23723][SPARK-23724][SQL][FOLLOW-UP] Support custom encoding for json files ## What changes were proposed in this pull request? This is to add a test case to check the behaviors when users write json in the specified UTF-16/UTF-32 encoding with multiline off. ## How was this patch tested? N/A Author: gatorsmile Closes #21254 from gatorsmile/followupSPARK-23094. --- .../spark/sql/catalyst/json/JSONOptions.scala | 9 +++++---- .../datasources/json/JsonSuite.scala | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f130af606e19..2579374e3f4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -110,11 +110,12 @@ private[sql] class JSONOptions( val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) val isBlacklisted = blacklist.contains(Charset.forName(enc)) require(multiLine || !isBlacklisted, - s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: - | ${blacklist.mkString(", ")}""".stripMargin) + s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. + |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - val isLineSepRequired = !(multiLine == false && - Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") enc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0db688fec9a67..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2313,6 +2313,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq( + ("a", 1), ("b", 2), ("c", 3)) + ).repartition(2) + val e = intercept[IllegalArgumentException] { + ds.write + .option("encoding", encoding) + .option("multiline", "false") + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + }.getMessage + assert(e.contains( + s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + } + } + } + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { val schema = new StructType().add("f1", StringType).add("f2", IntegerType) From 487faf17ab96c8edb729501dfb1ff82f7b2c6031 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 8 May 2018 23:43:02 +0800 Subject: [PATCH 451/593] [SPARK-24117][SQL] Unified the getSizePerRow ## What changes were proposed in this pull request? This pr unified the `getSizePerRow` because `getSizePerRow` is used in many places. For example: 1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80) 2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36) ## How was this patch tested? Exist tests Author: Yuming Wang Closes #21189 from wangyum/SPARK-24117. --- .../sql/catalyst/plans/logical/LocalRelation.scala | 3 ++- .../logical/statsEstimation/EstimationUtils.scala | 14 ++++++++------ .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 10 ++++------ .../sql/execution/streaming/sources/memoryV2.scala | 3 ++- .../spark/sql/StatisticsCollectionSuite.scala | 2 +- .../sql/execution/streaming/MemorySinkSuite.scala | 4 ++-- 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 720d42ab409a0..8c4828a4cef23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -77,7 +78,7 @@ case class LocalRelation( } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0f147f0ffb135..211a2a0717371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -73,13 +71,12 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. - val sizePerRow = 8 + attributes.map { attr => + 8 + attributes.map { attr => if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => @@ -92,10 +89,15 @@ object EstimationUtils { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). - if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 85f67c7d66075..ee43f9126386b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 22258274c70c1..6720cdd24b1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 0d6c239274dd8..468313bfe8c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} @@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b91712f4cc25d..60fa951e23178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), + assert(sizes.head === BigInt(128), s"expected exact size 96 for table 'test', got: ${sizes.head}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9d..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { From e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 9 May 2018 08:32:20 +0800 Subject: [PATCH 452/593] [SPARK-24068] Propagating DataFrameReader's options to Text datasource on schema inferring ## What changes were proposed in this pull request? While reading CSV or JSON files, DataFrameReader's options are converted to Hadoop's parameters, for example there: https://github.com/apache/spark/blob/branch-2.3/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L302 but the options are not propagated to Text datasource on schema inferring, for instance: https://github.com/apache/spark/blob/branch-2.3/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala#L184-L188 The PR proposes propagation of user's options to Text datasource on scheme inferring in similar way as user's options are converted to Hadoop parameters if schema is specified. ## How was this patch tested? The changes were tested manually by using https://github.com/twitter/hadoop-lzo: ``` hadoop-lzo> mvn clean package hadoop-lzo> ln -s ./target/hadoop-lzo-0.4.21-SNAPSHOT.jar ./hadoop-lzo.jar ``` Create 2 test files in JSON and CSV format and compress them: ```shell $ cat test.csv col1|col2 a|1 $ lzop test.csv $ cat test.json {"col1":"a","col2":1} $ lzop test.json ``` Run `spark-shell` with hadoop-lzo: ``` bin/spark-shell --jars ~/hadoop-lzo/hadoop-lzo.jar ``` reading compressed CSV and JSON without schema: ```scala spark.read.option("io.compression.codecs", "com.hadoop.compression.lzo.LzopCodec").option("inferSchema",true).option("header",true).option("sep","|").csv("test.csv.lzo").show() +----+----+ |col1|col2| +----+----+ | a| 1| +----+----+ ``` ```scala spark.read.option("io.compression.codecs", "com.hadoop.compression.lzo.LzopCodec").option("multiLine", true).json("test.json.lzo").printSchema root |-- col1: string (nullable = true) |-- col2: long (nullable = true) ``` Author: Maxim Gekk Author: Maxim Gekk Closes #21182 from MaxGekk/text-options. --- .../apache/spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/execution/datasources/csv/CSVDataSource.scala | 6 ++++-- .../sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../spark/sql/execution/datasources/csv/CSVUtils.scala | 2 -- .../execution/datasources/csv/UnivocityParser.scala | 2 -- .../execution/datasources/json/JsonDataSource.scala | 10 ++++++---- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2579374e3f4e1..2ff12acb2946f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index bc1f4ab3bb053..dc54d182651b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -185,7 +185,8 @@ object TextInputCSVDataSource extends CSVDataSource { DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = options.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as[String](Encoders.STRING) } else { @@ -250,7 +251,8 @@ object MultiLineCSVDataSource extends CSVDataSource { options: CSVOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) val name = paths.mkString(",") - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + options.parameters)) FileInputFormat.setInputPaths(job, paths: _*) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 2ec0fc605a84b..ed2dc65a47914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 31464f1bcc68e..9dae41b63e810 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv -import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3d6cc30f2ba83..99557a1ceb0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.InputStream import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.Try import scala.util.control.NonFatal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 983a5f0dcade2..ba83df0efebd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -121,7 +121,7 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession, paths = paths, className = classOf[TextFileFormat].getName, - options = textOptions + options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -159,7 +159,7 @@ object MultiLineJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) val parser = parsedOptions.encoding .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) @@ -170,9 +170,11 @@ object MultiLineJsonDataSource extends JsonDataSource { private def createBaseRdd( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + parsedOptions.parameters)) val conf = job.getConfiguration val name = paths.mkString(",") FileInputFormat.setInputPaths(job, paths: _*) From 9498e528d21e286e496da6ea9bf9c7ad73a7b5bd Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 9 May 2018 08:39:46 +0800 Subject: [PATCH 453/593] [SPARK-23355][SQL][DOC][FOLLOWUP] Add migration doc for TBLPROPERTIES ## What changes were proposed in this pull request? In Apache Spark 2.4, [SPARK-23355](https://issues.apache.org/jira/browse/SPARK-23355) fixes a bug which ignores table properties during convertMetastore for tables created by STORED AS ORC/PARQUET. For some Parquet tables having table properties like TBLPROPERTIES (parquet.compression 'NONE'), it was ignored by default before Apache Spark 2.4. After upgrading cluster, Spark will write uncompressed file which is different from Apache Spark 2.3 and old. This PR adds a migration note for that. ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #21269 from dongjoon-hyun/SPARK-23355-DOC. --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 075b953a0898e..c521f3cb51e58 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1812,6 +1812,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. From 7e7350285dc22764f599671d874617c0eea093e5 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 8 May 2018 21:20:58 -0700 Subject: [PATCH 454/593] [SPARK-24132][ML] Instrumentation improvement for classification ## What changes were proposed in this pull request? - Add OptionalInstrumentation as argument for getNumClasses in ml.classification.Classifier - Change the function call for getNumClasses in train() in ml.classification.DecisionTreeClassifier, ml.classification.RandomForestClassifier, and ml.classification.NaiveBayes - Modify the instrumentation creation in ml.classification.LinearSVC - Change the log call in ml.classification.OneVsRest and ml.classification.LinearSVC ## How was this patch tested? Manual. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21204 from ludatabricks/SPARK-23686. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 9 ++++++--- .../org/apache/spark/ml/classification/LinearSVC.scala | 9 ++++++--- .../org/apache/spark/ml/classification/NaiveBayes.scala | 3 ++- .../org/apache/spark/ml/classification/OneVsRest.scala | 4 ++-- .../spark/ml/classification/RandomForestClassifier.scala | 4 +++- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 57797d1cc4978..c9786f1f7ceb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") ( override def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + instr.logNumClasses(numClasses) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) @@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeClassificationModel = { val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 80c537e1e0eb2..38eb04556b775 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) @@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") ( bcFeaturesStd.destroy(blocking = false) if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 45fb585ed2262..1dde18d2d1a31 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") ( private[spark] def trainWithLabelCheck( dataset: Dataset[_], positiveLabel: Boolean): NaiveBayesModel = { + val instr = Instrumentation.create(this, dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) + instr.logNumClasses(numClasses) require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") @@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") ( } } - val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7df53a6b8ad10..3474b61e40136 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") ( transformSchema(dataset.schema) val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism) + instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. @@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") ( getClassifier match { case _: HasWeightCol => true case c => - logWarning(s"weightCol is ignored, as it is not supported by $c now.") + instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.") false } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f1ef26a07d3f8..040db3b94b041 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") ( set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) instr.logSuccess(m) m } From cac9b1dea1bb44fa42abf77829c05bf93f70cf20 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 9 May 2018 12:27:32 +0800 Subject: [PATCH 455/593] [SPARK-23972][BUILD][SQL] Update Parquet to 1.10.0. ## What changes were proposed in this pull request? This updates Parquet to 1.10.0 and updates the vectorized path for buffer management changes. Parquet 1.10.0 uses ByteBufferInputStream instead of byte arrays in encoders. This allows Parquet to break allocations into smaller chunks that are better for garbage collection. ## How was this patch tested? Existing Parquet tests. Running in production at Netflix for about 3 months. Author: Ryan Blue Closes #21070 from rdblue/SPARK-23972-update-parquet-to-1.10.0. --- dev/deps/spark-deps-hadoop-2.6 | 12 +- dev/deps/spark-deps-hadoop-2.7 | 12 +- dev/deps/spark-deps-hadoop-3.1 | 12 +- docs/sql-programming-guide.md | 2 +- pom.xml | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../parquet/VectorizedColumnReader.java | 39 ++-- .../parquet/VectorizedPlainValuesReader.java | 166 +++++++++++------- .../parquet/VectorizedRleValuesReader.java | 163 ++++++++--------- .../datasources/parquet/ParquetOptions.scala | 5 +- .../describe-part-after-analyze.sql.out | 12 +- .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- 13 files changed, 241 insertions(+), 198 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c3d1dd444b506..f479c13f00be6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -162,13 +162,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 290867035f91d..e7c4599cb5003 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -163,13 +163,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 97ad65a4096cb..3447cd7395d95 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -181,13 +181,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c521f3cb51e58..3e8946e424237 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession Sets the compression codec used when writing Parquet files. If either `compression` or `parquet.compression` is specified in the table-specific options/properties, the precedence would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: - none, uncompressed, snappy, gzip, lzo. + none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. diff --git a/pom.xml b/pom.xml index 88e77ff874748..6e37e518d86e4 100644 --- a/pom.xml +++ b/pom.xml @@ -129,7 +129,7 @@ 1.2.110.12.1.1 - 1.8.2 + 1.10.01.4.3nohive1.6.0 @@ -1778,6 +1778,12 @@ parquet-hadoop${parquet.version}${parquet.deps.scope} + + + commons-pool + commons-pool + + org.apache.parquet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 895e150756567..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -345,7 +345,7 @@ object SQLConf { "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e65cd252c3ddf..10d6ed85a4080 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -293,7 +293,7 @@ protected static IntIterator createRLEIterator( return new RLEIntIterator( new RunLengthBitPackingHybridDecoder( BytesUtils.getWidthFromMaxInt(maxLevel), - new ByteArrayInputStream(bytes.toByteArray()))); + bytes.toInputStream())); } catch (IOException e) { throw new IOException("could not read levels in page for col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 72f1d024b08ce..d5969b55eef96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.TimeZone; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -388,7 +390,8 @@ private void decodeDictionaryIds( * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) + throws IOException { if (column.dataType() != DataTypes.BooleanType) { throw constructConvertNotSupportedException(descriptor, column); } @@ -396,7 +399,7 @@ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, WritableColumnVector column) { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -414,7 +417,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { } } - private void readLongBatch(int rowId, int num, WritableColumnVector column) { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType()) || @@ -434,7 +437,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } - private void readFloatBatch(int rowId, int num, WritableColumnVector column) { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -445,7 +448,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { } } - private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -456,7 +459,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { } } - private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -556,7 +559,7 @@ public Void visit(DataPageV2 dataPageV2) { }); } - private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { + private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException { this.endOfPageValueCount = valuesRead + pageValueCount; if (dataEncoding.usesDictionary()) { this.dataColumn = null; @@ -581,7 +584,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } try { - dataColumn.initFromPage(pageValueCount, bytes, offset); + dataColumn.initFromPage(pageValueCount, in); } catch (IOException e) { throw new IOException("could not read page in col " + descriptor, e); } @@ -602,12 +605,11 @@ private void readPageV1(DataPageV1 page) throws IOException { this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { - byte[] bytes = page.getBytes().toByteArray(); - rlReader.initFromPage(pageValueCount, bytes, 0); - int next = rlReader.getNextOffset(); - dlReader.initFromPage(pageValueCount, bytes, next); - next = dlReader.getNextOffset(); - initDataReader(page.getValueEncoding(), bytes, next); + BytesInput bytes = page.getBytes(); + ByteBufferInputStream in = bytes.toInputStream(); + rlReader.initFromPage(pageValueCount, in); + dlReader.initFromPage(pageValueCount, in); + initDataReader(page.getValueEncoding(), in); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } @@ -619,12 +621,13 @@ private void readPageV2(DataPageV2 page) throws IOException { page.getRepetitionLevels(), descriptor); int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - this.defColumn = new VectorizedRleValuesReader(bitWidth); + // do not read the length from the stream. v2 pages handle dividing the page bytes. + this.defColumn = new VectorizedRleValuesReader(bitWidth, false); this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); - this.defColumn.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.defColumn.initFromPage( + this.pageValueCount, page.getDefinitionLevels().toInputStream()); try { - initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); + initDataReader(page.getDataEncoding(), page.getData().toInputStream()); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 5b75f719339fb..aacefacfc1c1a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,34 +20,30 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.io.ParquetDecodingException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. */ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { - private byte[] buffer; - private int offset; - private int bitOffset; // Only used for booleans. - private ByteBuffer byteBuffer; // used to wrap the byte array buffer + private ByteBufferInputStream in = null; - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // Only used for booleans. + private int bitOffset; + private byte currentByte = 0; public VectorizedPlainValuesReader() { } @Override - public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { - this.buffer = bytes; - this.offset = offset + Platform.BYTE_ARRAY_OFFSET; - if (bigEndianPlatform) { - byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); - } + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; } @Override @@ -63,115 +59,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + private ByteBuffer getBuffer(int length) { + try { + return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read " + length + " bytes", e); + } + } + @Override public final void readIntegers(int total, WritableColumnVector c, int rowId) { - c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putIntsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putInt(rowId + i, buffer.getInt()); + } + } } @Override public final void readLongs(int total, WritableColumnVector c, int rowId) { - c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putLongsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, buffer.getLong()); + } + } } @Override public final void readFloats(int total, WritableColumnVector c, int rowId) { - c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putFloats(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putFloat(rowId + i, buffer.getFloat()); + } + } } @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { - c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putDoubles(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putDouble(rowId + i, buffer.getDouble()); + } + } } @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { - for (int i = 0; i < total; i++) { - // Bytes are stored as a 4-byte little endian int. Just read the first byte. - // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putByte(rowId + i, Platform.getByte(buffer, offset)); - offset += 4; + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + for (int i = 0; i < total; i += 1) { + c.putByte(rowId + i, buffer.get()); + // skip the next 3 bytes + buffer.position(buffer.position() + 3); } } @Override public final boolean readBoolean() { - byte b = Platform.getByte(buffer, offset); - boolean v = (b & (1 << bitOffset)) != 0; + // TODO: vectorize decoding and keep boolean[] instead of currentByte + if (bitOffset == 0) { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + + boolean v = (currentByte & (1 << bitOffset)) != 0; bitOffset += 1; if (bitOffset == 8) { bitOffset = 0; - offset++; } return v; } @Override public final int readInteger() { - int v = Platform.getInt(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Integer.reverseBytes(v); - } - offset += 4; - return v; + return getBuffer(4).getInt(); } @Override public final long readLong() { - long v = Platform.getLong(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Long.reverseBytes(v); - } - offset += 8; - return v; + return getBuffer(8).getLong(); } @Override public final byte readByte() { - return (byte)readInteger(); + return (byte) readInteger(); } @Override public final float readFloat() { - float v; - if (!bigEndianPlatform) { - v = Platform.getFloat(buffer, offset); - } else { - v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 4; - return v; + return getBuffer(4).getFloat(); } @Override public final double readDouble() { - double v; - if (!bigEndianPlatform) { - v = Platform.getDouble(buffer, offset); - } else { - v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 8; - return v; + return getBuffer(8).getDouble(); } @Override public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); - int start = offset; - offset += len; - v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + v.putByteArray(rowId + i, bytes); + } } } @Override public final Binary readBinary(int len) { - Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); - offset += len; - return result; + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + return Binary.fromConstantByteArray( + buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + return Binary.fromConstantByteArray(bytes); + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fc7fa70c39419..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; @@ -27,6 +28,9 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import java.io.IOException; +import java.nio.ByteBuffer; + /** * A values reader for Parquet's run-length encoded data. This is based off of the version in * parquet-mr with these changes: @@ -49,9 +53,7 @@ private enum MODE { } // Encoded data. - private byte[] in; - private int end; - private int offset; + private ByteBufferInputStream in; // bit/byte width of decoded data and utility to batch unpack them. private int bitWidth; @@ -70,45 +72,40 @@ private enum MODE { // If true, the bit width is fixed. This decoder is used in different places and this also // controls if we need to read the bitwidth from the beginning of the data stream. private final boolean fixedWidth; + private final boolean readLength; public VectorizedRleValuesReader() { - fixedWidth = false; + this.fixedWidth = false; + this.readLength = false; } public VectorizedRleValuesReader(int bitWidth) { - fixedWidth = true; + this.fixedWidth = true; + this.readLength = bitWidth != 0; + init(bitWidth); + } + + public VectorizedRleValuesReader(int bitWidth, boolean readLength) { + this.fixedWidth = true; + this.readLength = readLength; init(bitWidth); } @Override - public void initFromPage(int valueCount, byte[] page, int start) { - this.offset = start; - this.in = page; + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; if (fixedWidth) { - if (bitWidth != 0) { + // initialize for repetition and definition levels + if (readLength) { int length = readIntLittleEndian(); - this.end = this.offset + length; + this.in = in.sliceStream(length); } } else { - this.end = page.length; - if (this.end != this.offset) init(page[this.offset++] & 255); - } - if (bitWidth == 0) { - // 0 bit width, treat this as an RLE run of valueCount number of 0's. - this.mode = MODE.RLE; - this.currentCount = valueCount; - this.currentValue = 0; - } else { - this.currentCount = 0; + // initialize for values + if (in.available() > 0) { + init(in.read()); + } } - } - - // Initialize the reader from a buffer. This is used for the V2 page encoding where the - // definition are in its own buffer. - public void initFromBuffer(int valueCount, byte[] data) { - this.offset = 0; - this.in = data; - this.end = data.length; if (bitWidth == 0) { // 0 bit width, treat this as an RLE run of valueCount number of 0's. this.mode = MODE.RLE; @@ -129,11 +126,6 @@ private void init(int bitWidth) { this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); } - @Override - public int getNextOffset() { - return this.end; - } - @Override public boolean readBoolean() { return this.readInteger() != 0; @@ -182,7 +174,7 @@ public void readIntegers( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -217,7 +209,7 @@ public void readBooleans( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -251,7 +243,7 @@ public void readBytes( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -285,7 +277,7 @@ public void readShorts( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -321,7 +313,7 @@ public void readLongs( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -355,7 +347,7 @@ public void readFloats( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -389,7 +381,7 @@ public void readDoubles( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -423,7 +415,7 @@ public void readBinarys( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -462,7 +454,7 @@ public void readIntegers( WritableColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -559,12 +551,12 @@ public Binary readBinary(int len) { /** * Reads the next varint encoded int. */ - private int readUnsignedVarInt() { + private int readUnsignedVarInt() throws IOException { int value = 0; int shift = 0; int b; do { - b = in[offset++] & 255; + b = in.read(); value |= (b & 0x7F) << shift; shift += 7; } while ((b & 0x80) != 0); @@ -574,35 +566,32 @@ private int readUnsignedVarInt() { /** * Reads the next 4 byte little endian int. */ - private int readIntLittleEndian() { - int ch4 = in[offset] & 255; - int ch3 = in[offset + 1] & 255; - int ch2 = in[offset + 2] & 255; - int ch1 = in[offset + 3] & 255; - offset += 4; + private int readIntLittleEndian() throws IOException { + int ch4 = in.read(); + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** * Reads the next byteWidth little endian int. */ - private int readIntLittleEndianPaddedOnBitWidth() { + private int readIntLittleEndianPaddedOnBitWidth() throws IOException { switch (bytesWidth) { case 0: return 0; case 1: - return in[offset++] & 255; + return in.read(); case 2: { - int ch2 = in[offset] & 255; - int ch1 = in[offset + 1] & 255; - offset += 2; + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 8) + ch2; } case 3: { - int ch3 = in[offset] & 255; - int ch2 = in[offset + 1] & 255; - int ch1 = in[offset + 2] & 255; - offset += 3; + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { @@ -619,32 +608,36 @@ private int ceil8(int value) { /** * Reads the next group. */ - private void readNextGroup() { - int header = readUnsignedVarInt(); - this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; - switch (mode) { - case RLE: - this.currentCount = header >>> 1; - this.currentValue = readIntLittleEndianPaddedOnBitWidth(); - return; - case PACKED: - int numGroups = header >>> 1; - this.currentCount = numGroups * 8; - int bytesToRead = ceil8(this.currentCount * this.bitWidth); - - if (this.currentBuffer.length < this.currentCount) { - this.currentBuffer = new int[this.currentCount]; - } - currentBufferIdx = 0; - int valueIndex = 0; - for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { - this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); - valueIndex += 8; - } - offset += bytesToRead; - return; - default: - throw new ParquetDecodingException("not a valid mode " + this.mode); + private void readNextGroup() { + try { + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int valueIndex = 0; + while (valueIndex < this.currentCount) { + // values are bit packed 8 at a time, so reading bitWidth will always work + ByteBuffer buffer = in.slice(bitWidth); + this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex); + valueIndex += 8; + } + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read from input stream", e); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index f36a89a4c3c5f..9cfc30725f03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -81,7 +81,10 @@ object ParquetOptions { "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, - "lzo" -> CompressionCodecName.LZO) + "lzo" -> CompressionCodecName.LZO, + "lz4" -> CompressionCodecName.LZ4, + "brotli" -> CompressionCodecName.BROTLI, + "zstd" -> CompressionCodecName.ZSTD) def getParquetCompressionCodecName(name: String): String = { shortParquetCompressionCodecNames(name).name() diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 51dac111029e8..58ed201e2a60f 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -89,7 +89,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -122,7 +122,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -147,7 +147,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -180,7 +180,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -205,7 +205,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -230,7 +230,7 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 -Partition Statistics 1054 bytes, 2 rows +Partition Statistics 1144 bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 863703b15f4f1..efc2f20a907f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -503,7 +503,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 740) + assert(inMemoryRelation.computeStats().sizeInBytes === 800) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -516,7 +516,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 740) + assert(inMemoryRelation2.computeStats().sizeInBytes === 800) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment From 6ea582e36ab0a2e4e01340f6fc8cfb8d493d567d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 9 May 2018 09:15:16 -0700 Subject: [PATCH 456/593] [SPARK-24181][SQL] Better error message for writing sorted data ## What changes were proposed in this pull request? The exception message should clearly distinguish sorting and bucketing in `save` and `jdbc` write. When a user tries to write a sorted data using save or insertInto, it will throw an exception with message that `s"'$operation' does not support bucketing right now""`. We should throw `s"'$operation' does not support sortBy right now""` instead. ## How was this patch tested? More tests in `DataFrameReaderWriterSuite.scala` Author: DB Tsai Closes #21235 from dbtsai/fixException. --- .../apache/spark/sql/DataFrameWriter.scala | 12 ++++++--- .../sql/sources/BucketedWriteSuite.scala | 27 ++++++++++++++++--- .../sql/test/DataFrameReaderWriterSuite.scala | 16 +++++++++-- 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e183fa6f9542b..90bea2d676e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -330,8 +330,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def getBucketSpec: Option[BucketSpec] = { - if (sortColumnNames.isDefined) { - require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + if (sortColumnNames.isDefined && numBuckets.isEmpty) { + throw new AnalysisException("sortBy must be used together with bucketBy") } numBuckets.map { n => @@ -340,8 +340,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def assertNotBucketed(operation: String): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { - throw new AnalysisException(s"'$operation' does not support bucketing right now") + if (getBucketSpec.isDefined) { + if (sortColumnNames.isEmpty) { + throw new AnalysisException(s"'$operation' does not support bucketBy right now") + } else { + throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 93f3efe2ccc4a..5ff1ea84d9a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -60,7 +60,10 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + val e = intercept[AnalysisException] { + df.write.sortBy("j").saveAsTable("tt") + } + assert(e.getMessage == "sortBy must be used together with bucketBy;") } test("sorting by non-orderable column") { @@ -74,7 +77,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").parquet("/tmp/path") } - assert(e.getMessage == "'save' does not support bucketing right now;") + assert(e.getMessage == "'save' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using save()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path") + } + assert(e.getMessage == "'save' does not support bucketBy and sortBy right now;") } test("write bucketed data using insertInto()") { @@ -83,7 +95,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").insertInto("tt") } - assert(e.getMessage == "'insertInto' does not support bucketing right now;") + assert(e.getMessage == "'insertInto' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using insertInto()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").insertInto("tt") + } + assert(e.getMessage == "'insertInto' does not support bucketBy and sortBy right now;") } private lazy val df = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 14b1feb2adc20..b65058fffd339 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(LastOptions.parameters("doubleOpt") == "6.7") } - test("check jdbc() does not support partitioning or bucketing") { + test("check jdbc() does not support partitioning, bucketBy or sortBy") { val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) var w = df.write.partitionBy("value") @@ -287,7 +287,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "bucketing").foreach { s => + Seq("jdbc", "does not support bucketBy right now").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("sortBy must be used together with bucketBy").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.bucketBy(2, "value").sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "does not support bucketBy and sortBy right now").foreach { s => assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } From 94155d0395324a012db2fc8a57edb3cd90b61e96 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 9 May 2018 10:34:57 -0700 Subject: [PATCH 457/593] [MINOR][ML][DOC] Improved Naive Bayes user guide explanation ## What changes were proposed in this pull request? This copies the material from the spark.mllib user guide page for Naive Bayes to the spark.ml user guide page. I also improved the wording and organization slightly. ## How was this patch tested? Built docs locally. Author: Joseph K. Bradley Closes #21272 from jkbradley/nb-doc-update. --- docs/ml-classification-regression.md | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d660655e193eb..b3d109039da4d 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat ## Naive Bayes [Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple -probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence -assumptions between the features. The `spark.ml` implementation currently supports both [multinomial -naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between every pair of features. + +Naive Bayes can be trained very efficiently. With a single pass over the training data, +it computes the conditional probability distribution of each feature given each label. +For prediction, it applies Bayes' theorem to compute the conditional probability distribution +of each label given an observation. + +MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +*Input data*: +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each feature represents a term. +A feature's value is the frequency of the term (in multinomial Naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). +Feature values must be *non-negative*. The model type is selected with an optional parameter +"multinomial" or "bernoulli" with "multinomial" as the default. +For document classification, the input feature vectors should usually be sparse vectors. +Since the training data is only used once, it is not necessary to cache it. + +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +setting the parameter $\lambda$ (default to $1.0$). **Examples** From cc613b552e753d03cb62661591de59e1c8d82c74 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 13 Apr 2018 14:28:24 -0700 Subject: [PATCH 458/593] [PYSPARK] Update py4j to version 0.10.7. --- LICENSE | 2 +- bin/pyspark | 6 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../org/apache/spark/SecurityManager.scala | 12 +-- .../api/python/PythonGatewayServer.scala | 50 ++++++--- .../apache/spark/api/python/PythonRDD.scala | 29 +++-- .../apache/spark/api/python/PythonUtils.scala | 2 +- .../api/python/PythonWorkerFactory.scala | 20 ++-- .../apache/spark/deploy/PythonRunner.scala | 12 ++- .../spark/internal/config/package.scala | 5 + .../spark/security/SocketAuthHelper.scala | 101 ++++++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 13 ++- .../security/SocketAuthHelperSuite.scala | 97 +++++++++++++++++ dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- dev/run-pip-tests | 2 +- python/README.md | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.10.6-src.zip | Bin 80352 -> 0 bytes python/lib/py4j-0.10.7-src.zip | Bin 0 -> 42437 bytes python/pyspark/context.py | 4 +- python/pyspark/daemon.py | 21 +++- python/pyspark/java_gateway.py | 93 +++++++++------- python/pyspark/rdd.py | 21 ++-- python/pyspark/sql/dataframe.py | 12 +-- python/pyspark/worker.py | 7 +- python/setup.py | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- sbin/spark-config.sh | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 6 +- 33 files changed, 417 insertions(+), 120 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala create mode 100644 core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala delete mode 100644 python/lib/py4j-0.10.6-src.zip create mode 100644 python/lib/py4j-0.10.7-src.zip diff --git a/LICENSE b/LICENSE index c2b0d72663b55..820f14dbdeed0 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/bin/pyspark b/bin/pyspark index dd286277c1fc1..5d5affb1f97c3 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option -# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython +# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython # and executor Python executables. # Fail noisily if removed options are set if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then - echo "Error in pyspark startup:" + echo "Error in pyspark startup:" echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead." exit 1 fi @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 663670f2fddaf..15fa910c277b3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 093a9869b6dd7..220522d3a8296 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -350,7 +350,7 @@ net.sf.py4j py4j - 0.10.6 + 0.10.7 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index dbfd5a514c189..b87476322573d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,15 +17,10 @@ package org.apache.spark -import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import java.security.{KeyStore, SecureRandom} -import java.security.cert.X509Certificate import javax.net.ssl._ -import com.google.common.hash.HashCodes -import com.google.common.io.Files import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} @@ -365,13 +360,8 @@ private[spark] class SecurityManager( return } - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secretBytes = new Array[Byte](length) - rnd.nextBytes(secretBytes) - + secretKey = Utils.createSecret(sparkConf) val creds = new Credentials() - secretKey = HashCodes.fromBytes(secretBytes).toString() creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 11f2432575d84..9ddc4a4910180 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -17,26 +17,39 @@ package org.apache.spark.api.python -import java.io.DataOutputStream -import java.net.Socket +import java.io.{DataOutputStream, File, FileOutputStream} +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import py4j.GatewayServer +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port - * back to its caller via a callback port specified by the caller. + * Process that starts a Py4J GatewayServer on an ephemeral port. * * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). */ private[spark] object PythonGatewayServer extends Logging { initializeLogIfNecessary(true) - def main(args: Array[String]): Unit = Utils.tryOrExit { - // Start a GatewayServer on an ephemeral port - val gatewayServer: GatewayServer = new GatewayServer(null, 0) + def main(args: Array[String]): Unit = { + val secret = Utils.createSecret(new SparkConf()) + + // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured + // with the same secret, in case the app needs callbacks from the JVM to the underlying + // python processes. + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() + gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort if (boundPort == -1) { @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging { logDebug(s"Started PythonGatewayServer on port $boundPort") } - // Communicate the bound port back to the caller via the caller-specified callback port - val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") - val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt - logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") - val callbackSocket = new Socket(callbackHost, callbackPort) - val dos = new DataOutputStream(callbackSocket.getOutputStream) + // Communicate the connection information back to the python process by writing the + // information in the requested file. This needs to match the read side in java_gateway.py. + val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH")) + val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), + "connection", ".info").toFile() + + val dos = new DataOutputStream(new FileOutputStream(tmpPath)) dos.writeInt(boundPort) + + val secretBytes = secret.getBytes(UTF_8) + dos.writeInt(secretBytes.length) + dos.write(secretBytes, 0, secretBytes.length) dos.close() - callbackSocket.close() + + if (!tmpPath.renameTo(connectionInfoPath)) { + logError(s"Unable to write connection information to $connectionInfoPath.") + System.exit(1) + } // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: while (System.in.read() != -1) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f6293c0dc5091..a1ee2f7d1b119 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + // Authentication helper used when serving iterator data. + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SocketAuthHelper(conf) + } + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) @@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging { * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int]): Int = { + partitions: JArrayList[Int]): Array[Any] = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = @@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging { /** * A helper function to collect an RDD as an iterator, then serve it via socket. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def collectAndServe[T](rdd: RDD[T]): Int = { + def collectAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } - def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") } @@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging { * and send them into this connection. * * The thread will terminate after all the data are sent or any exceptions happen. + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging { override def run() { try { val sock = serverSocket.accept() + authHelper.authClient(sock) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { writeIteratorToStream(items, out) } { out.close() + sock.close() } } catch { case NonFatal(e) => @@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging { } }.start() - serverSocket.getLocalPort + Array(serverSocket.getLocalPort, authHelper.secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 92e228a9dd10c..27a5e19f96a14 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 2340580b54f67..6afa37aa36fd3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) @@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String value }.getOrElse("pyspark.worker") + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } + + authHelper.authToServer(socket) daemonWorkers.put(socket, pid) socket } @@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) val worker = pb.start() // Redirect worker stdout and stderr redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) - // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) - out.write(serverSocket.getLocalPort + "\n") - out.flush() - - // Wait for it to connect to our socket + // Wait for it to connect to our socket, and validate the auth secret. serverSocket.setSoTimeout(10000) + try { val socket = serverSocket.accept() + authHelper.authClient(socket) simpleWorkers.put(socket, worker) return socket } catch { case e: Exception => - throw new SparkException("Python worker did not connect back in time", e) + throw new SparkException("Python worker failed to connect back.", e) } } finally { if (serverSocket != null) { @@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() @@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) - } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 7aca305783a7f..1b7e031ee0678 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy import java.io.File -import java.net.URI +import java.net.{InetAddress, URI} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -39,6 +39,7 @@ object PythonRunner { val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() + val secret = Utils.createSecret(sparkConf) val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) .orElse(sparkConf.get(PYSPARK_PYTHON)) .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) @@ -51,7 +52,13 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such - val gatewayServer = new py4j.GatewayServer(null, 0) + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() @@ -82,6 +89,7 @@ object PythonRunner { // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + env.put("PYSPARK_GATEWAY_SECRET", secret) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6bb98c37b4479..82f0a04e94b1c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -352,6 +352,11 @@ package object config { .regexConf .createOptional + private[spark] val AUTH_SECRET_BIT_LENGTH = + ConfigBuilder("spark.authenticate.secretBitLength") + .intConf + .createWithDefault(256) + private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala new file mode 100644 index 0000000000000..d15e7937b0523 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import java.io.{DataInputStream, DataOutputStream, InputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A class that can be used to add a simple authentication protocol to socket-based communication. + * + * The protocol is simple: an auth secret is written to the socket, and the other side checks the + * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is + * not expected to be valid anymore. + * + * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. + */ +private[spark] class SocketAuthHelper(conf: SparkConf) { + + val secret = Utils.createSecret(conf) + + /** + * Read the auth secret from the socket and compare to the expected value. Write the reply back + * to the socket. + * + * If authentication fails, this method will close the socket. + * + * @param s The client socket. + * @throws IllegalArgumentException If authentication fails. + */ + def authClient(s: Socket): Unit = { + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + } else { + writeUtf8("err", s) + JavaUtils.closeQuietly(s) + } + } finally { + s.setSoTimeout(currentTimeout) + } + } + + /** + * Authenticate with a server by writing the auth secret and checking the server's reply. + * + * If authentication fails, this method will close the socket. + * + * @param s The socket connected to the server. + * @throws IllegalArgumentException If authentication fails. + */ + def authToServer(s: Socket): Unit = { + writeUtf8(secret, s) + + val reply = readUtf8(s) + if (reply != "ok") { + JavaUtils.closeQuietly(s) + throw new IllegalArgumentException("Authentication failed.") + } + } + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index dcad1b914038f..13adaa921dc23 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.{Byte => JByte} import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -26,11 +27,11 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files +import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream -import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -44,6 +45,7 @@ import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.hash.HashCodes import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -2704,6 +2706,15 @@ private[spark] object Utils extends Logging { def substituteAppId(opt: String, appId: String): String = { opt.replace("{{APP_ID}}", appId) } + + def createSecret(conf: SparkConf): String = { + val bits = conf.get(AUTH_SECRET_BIT_LENGTH) + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + HashCodes.fromBytes(secretBytes).toString() + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala new file mode 100644 index 0000000000000..e57cb701b6284 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import java.io.Closeable +import java.net._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class SocketAuthHelperSuite extends SparkFunSuite { + + private val conf = new SparkConf() + private val authHelper = new SocketAuthHelper(conf) + + test("successful auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + authHelper.authToServer(client) + server.close() + server.join() + assert(server.error == null) + assert(server.authenticated) + } + } + } + + test("failed auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128)) + intercept[IllegalArgumentException] { + badHelper.authToServer(client) + } + server.close() + server.join() + assert(server.error != null) + assert(!server.authenticated) + } + } + } + + private class ServerThread extends Thread with Closeable { + + private val ss = new ServerSocket() + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + + @volatile var error: Exception = _ + @volatile var authenticated = false + + setDaemon(true) + start() + + def createClient(): Socket = { + new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort()) + } + + override def run(): Unit = { + var clientConn: Socket = null + try { + clientConn = ss.accept() + authHelper.authClient(clientConn) + authenticated = true + } catch { + case e: Exception => + error = e + } finally { + Option(clientConn).foreach(_.close()) + } + } + + override def close(): Unit = { + try { + ss.close() + } finally { + interrupt() + } + } + + } + +} diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f479c13f00be6..f552b81fde9f4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -170,7 +170,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e7c4599cb5003..024b1fca717df 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -171,7 +171,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 3447cd7395d95..938de7bc06663 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -189,7 +189,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar scala-compiler-2.11.8.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 1321c2be4c192..7271d1014e4ae 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do source "$VIRTUALENV_PATH"/bin/activate fi # Upgrade pip & friends if using virutal env - if [ ! -n "USE_CONDA" ]; then + if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi diff --git a/python/README.md b/python/README.md index 2e0112da58b94..c020d84b01ffd 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). +At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/docs/Makefile b/python/docs/Makefile index 09898f29950ed..b8e079483c90c 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip deleted file mode 100644 index 2f8edcc0c0b886669460642aa650a1b642367439..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 80352 zcmafZ1B@`;wq@J4ZM*wx+qP}nwr$(CZQC}!wmtuSnLBspB`;ISE+lnMcBQhk0v&QGOsvq7S?-^UssrKY{S?;d*)&wieEMdUW<4|230^y}AYwy4HW6Nl6u= zoaJ5-LR@-QR$A^vl6rDZB|J`!nwFAQGA2%Ke42Kgo=QPnnre2Al2#%n?2d5qNx#vg zLS!W4-9uaZ{$8F;zKm8{9bO_C`oE>MI*Aq63km=LhxLD@WoTezWpC%`{QpU7MO`Xx ziw&XoQw?Flz@#5B5@*v;0sWj>WS)p39@G;8D1fR(qEen@k{AUt3$zi z;{dihzth9IJM2krlN>HW3%ZU@f)zh#x}qz1I%iKBe-(PUFCT?dhT0aD|MeuiGxy4R>IH0`W_ZMf4)j+w+$g z&DpJ|DXf0EK#%-2Xw7m&$!gLmp+z{0+|j)9=N-Q-kysDt#|9;^#5;m^kYh*zh&4#| z4df7CH$H;4FfyEB7=l^*M|dV=3w8Kc42583#*=jtV%zL*ueu&U2Oo z&)ds{ur9liY+7?~5wuIlOP?FV65|eTeC@3ZEeknNdgMEe>Yr-~rnGgV#w^%xwuzI{{u3=pNbvv`QwzoNZN_kuvWq3t?@y$ z8MEXqc4>Y2dgo}^sFQb_jN``W;U}$r=9)Y3S(74`h~W{;EmYZKnY^q{u~4q{4VQf~ z79KOih$5^^(v5a|aY3h~Q(t8JFvuJMhg%~`X$_cd@_nb7q3Pt3mWyq?5X=1Cj-+d>H*(Nq>;;*OGB2~ z6}$*}Tnw0Z$!e_-f+JI8IeHLiZWRdIV|J<#VaX-l9)ZP_)z(gWnlBX7X-&o)PRc&< zKOTvBXrbPmD>-L= z`cZbtSy-iQOV||H$!Y_4TiFHsY&F^43qL=mKdxUEPMtU`d(E{95DdDFx%`kp<6>h6 z#S4|1euLb|Jb~|)-B;LQ%_k-gCC-eYUZluQgEXN2=B+z*FU4ld<*LphdGKbhOaGWD zJ*gL?wLTtS-$re89X)Dv*)w3Y=s!~ySv#&?=eaJ+n$M`ks-Q7b-^}BK#+ELLM*q4? zzks;u+q55ct(g4>p8kby<~YP%=coVxzPbPaDF22hBWnv2TW2Q|M^_Wae`CSF!0uo1 z-@xu#`y%P6HRZRq4`{)gA;nzTHCwqvc|~jGVl{f}!l+g2NTZH@q<}!kx?S`GNk;m;nclK2jw~ zLSc`rMZB#`+SWu@jR>*IxhIv>J77;TLtm|dR+R}ET4PTt10Lz#=z$QCuf$}BB%@tC zwatRA;3~_2F&aaf+Te*<7ur#?_x}v1l{r2iA|ziF@?!8en1 z0+yUUz;*&d6itxsh^nKNH%F=JwVFeo={PWkG99)Ej#hIr`S{v^z7Coe;@DNGtt95y z_)cN9NK%NtQ!>4}8q1?h5+aHZ`(SP9qpoAG zp|pO?h_NlcHeZ^L?@WKo@Xbs#0MhQ~q!?NYyOrcBY0sUy6RXp(QPdlD0?Ed>-U7(u zFB51^`qfQV+$imO-#`mYU4RnMoxcI>XkFPQ%(Bm#W;-?yH)njvkwS zb9KgUYiobmy1eW{%z_!8F=R9wi)_mqCoE8czP@cTC(LyM2ypy#DnqSW@i#M#x|%?EQ7L~}E!QUlDq)X4OqO7?x5 z9sCQHB{GA$z}DKWKFje3nu5Bv1+45-_3_2QJ*$^sIYjTKZZ~4;-}~~q*MQ$GjcOuX zC^TyH{mNsx$Nb4$j6J{;SX>b6lrm~!&7X~;@V+H~ZXsNRS;Ha*t1VXq#P^*s%|wt$ z6jU%>%TLBW&Kt)Yd~#`+FzrrSDP9@Xsh-AgS8)uRU={?oBCedNqePLlNcqXYzi$i7 z%2}`ZKc@EncP;Tu#(PgSNT<9E#W>*dMSoZQ^s-Rwp5Kd-h z-XWRydEzVVYG!ZK+8jO5_s~>wf+FX$_m29Zu+IP3U}LytElYKIfk-2acnT}62p3uw zv+t9RQ|gpS$?JU7w*C>bke}#Rn1cG=88$+ygmnt{ei-`8s9^}}LbWE5gb4F~l-KwB zCTSRa<462Kdc<>S(4K39~X>J1BqRJETL>by{2Hkqof4uyPm) z24-Q)&~AJc56Z?}XSFQG=bP#Vi7e;ErG}p2mu;LG@Dww*L zCwP;<%9KE8qz+f3AyB#6w(TMnJ|Vz?d9)(SU(op?8ead#DMy{;eYsSc;1F#x*lJ!C z2CR~rVu}{{HT-vnv;YQ6^+2Mb)9I^oz?f7b z17Iqfu6OUnRtXUWj{WcV6cbBNHwjTYWK4sh3atk=sFHQg>8KWmWO+&(HL#RK3nFJ6 z3(tj#_Yo-2Fa3HCuSEm&)CgMyfi%WHX0c?aqR$A)zb(S!ERUswrmB=4hb5?~Hkl0b;BN8)z0 zHg^Qs;|E>fZn)76JSB~DVuNgTmUfB#IqMeh z6vxngU2+c`kU4|#I`@W-iepJl3|@sN7P8zX-Pnj;akhHD2ZWwKWKU96v#dWF^cItU z%aS4PI>Gk_?AwvoAcbdU;83eb{q;5{LZ+Yez-lM#BO77aSw>%gSVP*LiSa=?`OhR6 z>Yx32=Jyj2=2ch8f%usJ%=l9_%I0$VyUwJ_TtFHV5AYsaL}fn%T zb7qzq5c42X49xd&DK~U-tqf0iQ@q`hQLK|iRzYzP!`ts*_GvTSFZ zH{BpjK;lBcWc4L{l?r4WXN~BNrG=KdD5fR;-mGYN8(J3Gr zt5B8|oMeIkWsEorBX$M%VcnsBgpR%0G_M0cSuf;tj*0v3B?XpXt zG}z7a+yG4#ehaCMCx%&G9LInwL!GSXQk3jY&)B|SKZEyM4ed6!dF|xK!SDgXrIBo9 zT$|HJa>vLydtmJHo^b?$>XSxd-4@FB}`L(=h^|tk}qSaat#V(=3aGd`WDp zy0f^*mHCrN0Ws>ML5{1X-EdZE7(sRG!86)6_!Cbtp+4iIeshH22(O&+#jA5)w4=V? zyBT#9VR+shhEWFwY1d8Ak#Cqfwy9yiYP(7%>IY9t;gjw){ukOIK1A&rr=9yx zBsm7GZiWRfdF^V7YaB!)d?i{Oem8?C2WirE z-_9scgUE^|FZ6M_n;rS}e#^bD;d4~B3jXFGye&AQSKut`<)c85S2i8nXB11FW%WU! zt|b#aI^_B|26Bohn`1G&kfxc{->1sIgOeH^uiU+YoJ(k>C+#*C->J)Afj^JXH1e#~ zTz~$kh4P+q44JH6tVhl<%&#yU@7t2mi=3FL2%5R@+;yr$s%pz{lWkhmD1D^J&+ww8 z-I6|{nD->^BDl%8SZM`#NI~_}IGeHAy!@LwVmS^6>nGe z3@cH_wsy`JZEuF&Tr7PzHLt*P{jASW1vXE8wMy5yfaF2L3{|XTd8<`hOQ7sfNA<{5k8_!YBj_7H`zuvV;uW7hTrtL_|DYLr2KQc#5WKD=BoA-g zFUPrd_}TR|9w&o8?OW(yB~dn~3p+=Hh~_l$graf_`UqbS;y>9%st6lM|RTj6uJGEp(Lhp}dmhT`Jz+OIkLg+f=*N8J$NAG`k?&t0&Rz^nd zPb|DWycx3q)@fAre`I=m-_s1n?e3(zt=;1Y`w+@PW+T39t&}UUZVfxE>mB{zw;Px4 zy;WQr6QfL#VYgzpi;5??@05ncs5S{rsao)2gM8KG&c0K~g^=9Z@yHYTqS)=T9Xs zvQi3j^>TD0jxz0zmngwhHUj{aoV@J-bSKE2Z)4UL8$>)8-%+zil)oOgez9fGUzBh! zB%(jAiHaXO9qZMibWBFJ4rd7?Gf6KY&T%)JR9RI9)RU>;2BEiy8mbh?YGxVJU&nl= zU}A!j$0kk9*4EV4?pI8_nY*}oj%m_?WQW6OC3JU?C0Xq1vgktW6;e<2clxCDd|zJ^ z#0FFM8IhF*3Wx>pY>t=Q34QyCe4QVh?R3Xit_kWO#N?1PSJk6+jBf}dA} zRCKBqRnN8Kg?>HHebxtOauS0t>aYxYoeUMsbpa19Tc0Z)Dr*FelymGA%mF$n#|wwN z=Y)g$@MQJ%^hVR0va^rXnHA+m*ds+*d_C+u-Rf4|UAumr?|g~7z_U@I^^IO!I%E!3 zGC?dBJW8|fCMcdH9X&d^U=MD)4>nz^V%q_(8n&o-)kD6-5^j$93@_nE?d~jm1#gG) z3?J~+8#$P})Wy~PLInnuDs4+!t)<2`43iKPp*L3-%q0nl&dpQD4 zVr)getD`+z5c($5Cs+&HJ$7viXN9o?zdj0fC+%UVAvxq78gZ=aIG1qh@c5g3@9%@6 zMk$G;^VXa-e&3)Hrl=>6~IE<&={1?0W&q7cV3WYMmHlZ?7wYpO$4TK0Us<66 zD!x|x=1%QiBAA$W_}eit8GG$AF-~5(S{+V-abENe)~KXw1K=B<%wv&uz+(qMwDjPz zfWOB5r8Xu}Gx|QiF0hk2SmCY^(#&GUOkV36XK9|xrqcMr1SSV$v*Y{y*0Wx6AAP+| zYQRKvJ{`IS1)7P6bL8uygFE_r=C;2Wlw-rkr-}^s<|WVDf>FdlmDCXkfHi^`{mIt- z{FdgAav&!}m4{g34Hx|2u`q!gKY=Vik;FDBCq(=$e(Y2bUq*<4Z1xoI1@_?%CX^S# zxr{KN-0kY2YxQ$c~tyfThY__XGrb&DzM|)YA(yI%w+a+YdZ#( z@9T%a?XiH8r~6tSsfJ5y=+MY_7{ptNMXx`c!s@fSz99Rd1+-zb+gYoQ-13DCSr0(swB40%H{K6nKqsta;wwM zB0BVw5CYXJE6ltzcbHf~h9aXS@;qy2AJgf!(k8&Uo{f?zQm3=AILmZECaOgR_A&T% zLUzUH&$s64jjze4Ekn_{L{5SIZY+|niknyt-37dQ&~;kVCA+YOoBg)h*yUCMcf#%R z=cv7gw^s(W_ciZ ziELhk>M-5fV}wPqe>P^Ddf8nqo-VGmwtJh(##=-2wlrO#6Yg&Q zXQFv(Dt~n=+|(;lEmvhdpk1FN^ZZ{dvov^Z*ez-0uNQx?jM^68nb(v)MP^@Mz<=TY zU)Go^*?s~87yuv<@*f59|FFh(HueVpM?3r<_~I7Tnb;);gx)hs(H*HyD#S~oU@A}b zf)HRNpnze(8Y1!x>z=I5ifij7^dnGr#!*+_EL-&W^MZ!H_tUYXIR`p|6ijqC>T>Cy zLm7swnLsy3Ev6x4U3(}0ZlsC>Mf35JK*_soBQQwYbz*V?fhIAOyLWvy#eSyBfC!k{ zLQbea)G7!S>UO&NQ25I=xrT(1Rg;O7Z9$x~lnRtl#ycCzlm^IhqG(EXD%Eq)df}-; zCJZl81QuyJ=ZFf6vX6Wp=_ZBoznPoydpHOnapf$uB&S4a^{sFV2!;KBU;^-~PH3yw z5a`hFJbR)1ed&Ze++FZVX?*Tn-e~p`c)Ue`?EwXE)cr$71Lkt4p&rZDleX#Y4rsOC z0WzMM+gA)$lMVc!Dx82C$Foic;;UvKSBAy(Z3aCh>xj)P(jz0dig#Ch=T+=%LW)!C zzKx8<(?xAi+)f9Kg)LulFs{vUj`ZUxN8qHei*lZ7u~*y|^$*40|N3H|5E9Gz*V=V5 z277Ko)JoWgMxY2GktT$~u9YPHu3-MYl9+POo~Yh?x5jdgTQCxLFc>hoq)`#SrCU>T z3CMclF?zer-eH0IRn!ZDZ(B=w%2o3qbz}_fbPw_P+0)EccJ6Nm=QRIHZuQjGI8yfc zM4w#c60a!Y%Z0ZniPg=*Wl|{OS0wzzk#gj3v^WcLmjIYGX7L$uo_MqUlIL(9>QOy+ zK>RATPvBdw4nXrC)YUyVOY|hqvA#wC#EiP-jB3%|f@$Je`mxXVdS%$9n}$YO+fr_h zadYV20oZ|_p}*tlpoeT<-5la8h;Q~UqDTQ_4$*Glzeq2{NCP2m3)>|6 zb{(!`9dazIA49j3u-VL=lssbJ=p);~<`wtDv0lzAiL&4mP?=!Pa*JneH`cqd`VJ4p zC_WW1K`AT>a`LN@R@$#Kc7l{Idn~K?YCFyUNy78KbO5@-cM16rxWoQWef`Tcp!87` zy!20fEg}E_ApX0)Of75;tSvnMBgy;^u7PW{ov1BQ1iv*sMs0jcbgg32c_NN;ohbnV z3&I}WltV2kVk@UAorbQ+qu{UGobPzjjfb6%wJ^a#axa#=zV{iJ53tldT0%9@v@DpP z?jDVwW0HxH5s?Q`)EfSJ)FyfT2~sS5`yo=~FJjAaP-ZZeOi_uUwPH;wnP2v8_2pxP zbQ*E!jAA~C4NS5LdaR5QpyXh7>7h92w}X*a`a}|=nwsR<{Y(q3f__i%zD8Kd(HOak z<;u{NN^6@chZZ1Rc1bOB(lR7_QVvklK^p&>7^IpJm3i^Nk{XJ7>K112PtB9pdy^)w zce)vs9r=!X){ZKfM3RC6$$Q+cOQIC;3n{d-eiX@NI!MM9UfB;Tl#D ze-*E8{iwgYT?^`kq8Z4!Il#coxrUN19xVZrsaui7e~=Al9zH!X-SOK{MPLHj<5XIK zT2Unjv!m0SO;NI+8~9U}IUrIz%y>m-U3Q^LKiA1z9PWZo)-AQu12XNXppo{HRz}K0*);O zph^{MDEO(Q4C(>VF7bq9CbzLrOkms^mjX9JQmva{58$@p@&nb-`rk%mnSu^^UPFuA}Jk2SHU+sCt1nZO;JHqxEK^rXX)% zVSfIzB1Ur{n3GP>ppnTyD?JO|kl6C*$ zLXb&c`d?T%#wEQV{ES5PLAKLv0*_l%uXu5T-6AlRv4TZ=S7d+IoIW<%0?Y}{(azb7 zJWfJ~)gT0#8Y}AwbR`J#Ag47e%uF*Lv1>xXBTio1?-UsV(FjN;d3gwNq08wDPq^8T zSc473#czC0?W;&Em1uQqxq}B;=wrRE4BZAN?>YuWDWph7?qCROlG=_S7c|)Zy@|I3 zE+ZCmxQK0mWr>u!stw+sc|IdN*G(4lD-R0!V*laS3FKasDe7kB)7E*|@N6%9XU z^mJ90I{Cy?h6E#%-fuD=8?<5GQlMfWi$6*KL*=_rc6;UH^|yBf_s@gdi%-_@&l9TG zaUiZ&=b*HHLT)IJ{fR~9?|<{|+%}|ns=xvO0Eq$sVElV^urzQr&@-~Lwl**@PwBM)bb!17sU4vUPClTMJd_zr(4rZClsb zCQB+r>U76UQYPG?w@ozQhfkWvpbYhEUztOh?J4Bhcqo|p{(Z&Z`n0aa4#H3Q>5z!W1)wA)Csb}b5HAFY2Fl2B(#IE zXk=@Oan0JlPYrw%Xi~aw#08;uc7U9vhiGqS+7E@}17Hm}bx)n4cdt-e#5RB z$2@GONPYgX1DM!=A#1pTZ}R&I59H@JL{mYItp|XR{`!DslA|*FWCU^1aZecY+9F+R za03z3MIHM1+VA-K1tb!#z|?@9&BD~x%_^KGe}bv1F=PXIVWO-ng88R?^ntF2Y)^?u z1ylz?R0T2z3?=G>YqLb~f$>WECj-*smfFJ$`6G>2jX>uXiWxZ+K6x@p8pLOw?v*x+ z`&ipdLl`Hsfxun5;2~ukpd|pOX}k)?#$T#fI-B*0cP=~zi_A| z7(B&mB4T5bQGKu|sxwL2lK@O70}0`{e&bH?Iegcw>|9tGtasyb<8qH>_GR4XT!h)9xO)w-enyzW0DaR9|N5DI!O z&sz~X5y`SPkLV=S`QYNBw^d;!10@$ovQ&x-O}&?b>JmjHcSDrM=y7KQQb20+Ket7( zjla{D)`Xr}f!p~Dk1Z)A_Ul%(03^>8ZzcsSMWzQrj5-C#R|N`xeq#dz%Ln94_6R?^ z6_j9iP6`0uxVi?*8SZrZN|^fZ&~;Tey<%Jw#G)96=@8j_>JUbVvkQvpjs7ABj4`23 zfVD(WqDTGJTxRJ-F8}BU{3C?$0^9#4G^DC_$C09yJ{|ufR(+L9eIrjC2^Fb+x&bPUHE(%vVSoybxsasJ|7Ai*`YCY7))oTGFazTEg3%fB7!Lr)D*RVzhKWW zsDo@C`#tSzOav%TG+AF|0^r61Bv2++6o3khwjr6= zu+t_A$(zGg2-56U)f~@v*PW2*Y58RN&dO9yhaRTSzq1jJTW&5t5@SI_&b>EYMwmZY zk(Grx;IVvcJ166XQf;x>w_O99z11AhRxaRLRck!9K2S^G!Md3Md!Vwtg_Mg}54w*Pr3p3}y#nCS12~LxFY= z4I7qvXcO7^@OpI^d}5! zw4!L#`vEczEn6u>JyACk(aGev`CJl|A)4oax*|C+Kwm_@3xq&+HCm2hO6&BVi0;se zqCLNS5>)M!shC~kCjw?pseaUxk|1Sm_2XmyoYsFNah7bcqTPiVoA5^%Og2iBza+o(ceM%ENcj z)Kn2);Ssx;nb=vM$eEZK@XME+oK@P5oVft!Zv>}2q8w(r`cvr#r;>k zYar73#j=a9Z8Bc}yl6;N(NsdoaYZ*`2Wv2Xnh&Ku1TD?{?^=I_r#CW1hRsNJ-me%_ z1i2{}$kh=11&1a35xksSEO?xeF9;N>YkD3P&$s%{F~Q*%SErd!4saA#YP|poj=!dg z+VRvRScT09M9{~FntlS6wJ>N4#nfFsE8ds>A@LJqX{G4j0|%X2aQNL$3!!{OYj`aG zWiVsY76AL99E4C5wRp!@^R_^ctyHoW8T=Ul&ixl$kvrHqib9HNQ$olJ^n-!QHqmhu z9+w}Gu2yD+2f7m-ed{m!dSPD-BZwpnze=h{WpEq;N*Zf*8g)L9x`GmbAmel5HX1C4(7tciiCf} z>&UdHzv82O;fEjjt?j7VU$NOx^p+@Z7Y{%%9m-}filb^48B{`|LLJI&8OZsd9r_W0 zO(V*fYfK$uwB`D7%QgPixX0p;mKMt}{0XD@ZVS`}@BoHU$Tu6>U74&jY!mxw!H1)Z zZtVP=xY>N11Ux+Z_EgJCcS?_t+EO`}x>N6M4zC>H#5b1`I9{M!`Rym`#W%zZtqou* zY4pR-w!iyhuY_F@2FZ{*V$9v|y0tO}QiaSea<3n+%qkpkE%_L?F{+7!@MIW4E&A5k zVcIbhCK}ju|0)Rc4)Axy}V+TtX)i;s(m4nwy$rqAxm5FA|4DXKQqtHMVXm;H$wgtbtM`w}w4es~`EAXwq_1mGEWPABXYRpB33mwlG#kd#t+< zdG-FnBF7>lx_s@`yOb-X82q^8`jIS^D8$?Miyq`e2SgQEAf{pvfiTk-c6t5zJP zu-Y$E1u46oUs^b)YeaN;RAQ3?a8oUfzr8N$sn6zWhdRJL#Q`r`C+yeQSt0BymRf0D zZySZjE5_K9y4TG!CG@_n)PDDmcSz~&_ysVda!mmiWP4KwqmFMsVlQRnLknmnj0qdS zpSNevdvjO#XWc(mG?6p#!9I{f?8``R8+2V`a5)D_s;-Nnk9=|+>tj|#)Q zM&b2!8rk?9$|5&I%Ltwc2frt$KFA#0PPUASx;2yA8T}_!=B{9NlGb%uQaK7R4)S3| zc_kD32`NNqnoKM%h8}5Z)-h=@q(X7AXR2Z=HlXd8>oU&^%9^w0NVb!nW#U|rcCs|S z3s$UC&}hH9Y-;W1@O^lwL4SGUshhfAOqY4wY7erebY|Y%-Ob2D?UGIux%F*p0GVS#1Zt;^8KweJAqgtvpq*K6X5<)C(?Jql zE7Lr*aqqXCIeKqh>jVB|vRh-}qt;E?D|zK_b!RlCh%|cKwT|zw8$()rSngJ}VxvtP z#+<+RG{xtm=bo>P>Nv{u5HG3m0o3Ry=%YF;{o+r?l36*?-qjf~=BFLwNqrFgh|NT(*6uKlx_c%~+G3 z3%&*~wd;2M(xyBIyo|l$<`0W4D;K;s6P7k!%&zvXm-FWt9b3hn-+xG)e_3mJYA4Ce z! zD$#2jTmYU@p|5=n$e{|JMjUczXuAtDjoDLq*mf5arEniAPA66Gb^NCNmA$Wn5_ zsnS=iEafNB_>NP|rq`6%b#i6F#)#{w8iWY4ue=5JEVX}+fa=zO&FDcPkR7=m(M5b8OY zfK^A&nI1HqbkrG>0q0@o%ppwfP4e1ibRfJ;dChGTA+!io%&W zf!Sq=$4REv393fQ01br3gIQdmkDAeoFGD$~JiFXyB?oZ!|5?!4M9&peF-;>D{%F}|=u)yiOsV&tPrVmH_*vI@v1 zg3x!56aSFhigdY&HmJU96J5=)VddU|9GtOc4a&JN^bG}^l`$8s5QefN6#jt*_Pdpv z;GdjKQ(n*;!DQDbSiZo%@YDMI+vCn1ixyq&KJSgbNHFr>l`!48*WZ`5Rfpe`h35+f zPEOw5OL(xm_V1?+zu?ciy`@$kFAoRSV6{2lx1Z13ouRz2I{Mk7sj8?v-xn=zH({@y zpA!bG-0Wxi+Jd1Ag}UF5!>4w?kfU=|+vMGai;JkD-MEE(zh|_!YPT>tzGgZ<-ybrp zFLy`GXF6HA&z^&ckw1(*&mvh{@O8BcyNW6UpD&sz*R|$#_j!>&l&Y2AfXz!(Du79G zp{b4Zk&{$YPWEA8o3{G7{TugY`Dufbx)d-Fn-}0b86-x3IxfNi5s_mt*30Lg`WJem zkn`&&AB|;V{a74n#fS-`-l4QTRzz-b%l2oMB9Jzu(*ZqOSrM>-0QXHJt$IluOEXRf z`+lum)TL1&%bYxlfQ=4!YdXfS_SpxiH4-FZ4t>UUC581j{0G8#pQ6dT=;Qs zk^6ov3Lol{KCaG|Wscb}-P6OS36nmd1bSUjJ4;^AAQFMemGgqDkQ zz>h-~x2;V708BJCezX=6qOxr9DU_p)ItkK%gh7s)17OTVS*M7xyD_+9o+H5Ld{{vC zs|k7G0LuF@p`vM}8Wt9?k|6-Bp)*u_NU;d07PD#(hxeldddY{1(w|s1@|&}9c(^$7 z9xDTT0w1RwJNF|we;{DRVxZ3eP={53r2nFaCt9qy<_DYjPkgk%9#zO$BF!wm0J{3i zXUs)?F}J8lRJ6VjMf8izApM~)C<>Y|xUKYK`73tz?G_I2=}%LrriQ+jmZu>;>YbUo z>L#IXuTlvC0AJn+6Rm;h?p6ubtp#OC=`59Sg5z$Q3^7pF#R2+B*&<8-H;Asn!X8oQ zxenY1gs1dYabnrnw5mAr5yU36ItbGrGnq00n;LK`?vaK>t$Sly`sMhi`61Q z*9TB<5~w*$0aKrj6itE+Ebo!QSb_%+M(_&9(#O;BfZG+61TIB_v=gWX%LnGzAj}4V zD&Qms2W#DT+t1!jA24~fkIf-2zbVD!j~K<*)sNW|do1Z4yHY*> zU6FoAh^A@V-%#nx$E&5eYro}Cc)(#AlNOQGS+uk})T7WsM#oRA)aSZ6UT9jKH~m>l zO;ff%0#FBP{8+a{Rk%fBRr=(RJV;}jKm;GMXq3_l{NL`Clof~8jBoB;(Y8h|YbzJT z{3(O9$;h6gfaKs>DNfyoBl!vUn$g;?h?WqZ2#ERg@xA;GQuzhVVqD`3 zhas51{QBX{d+|W}is2aQBZu9p5A|{^_}x_S+x5}`8qmrd@LeS4?Ujcl+#$lO_l#}?if zH1xi@yfAaJ4d}37YxP(*IfNX6N0rXEe#ACI^$mIS2(2lt3V#Ze!8mJHP_Ju}l*DDP6Hax(>OAHlcp5mMRjKRJnrQ(0%@)9j7c0 z4pI#ei7h9}19J5sRQ#8qM)VrQL2w-`$T4fKI;F;5!ohd_^pipuP2=GU#Qu>5PzBvr z*`U{`>lx;xAmzy|UAI0B`LNiN0WJlyO1b(xL6qce*a8Wn_J-m#vN!3M+WSXN4l6tW z_sIxNM1(OzulZS^hRDD87$aJLKy1ch?Jx)`diw`Q<3rOv)$4Ow8-OOT*5R0x*DTGV zCVLauCPQdY8tP>bTM%)_l2jD@0dG;M6dIMEbM^FFw1?URXAb51`@f6$kUX-P+rtQi zQ@h)V4lo*RE7NtL&?X$g3LN1Html6@gNLs{xr5-(qs${!$-2 zkKM>6y2Dr-q&_YsQMS+VT(e1;uGUS5yl-L)E*MtDEEZow)=YVipgFw05a6NpDO~#T zK=#E{;yzBUJGmGZLO6(ZfkIRF@enlQRS`_0Q?5tUs>5jlS-) zLCYLd<@M43+p6H(HNS351NF`T>lUAB0R#>}2Moz%iRVjh6`(X29A(tq%wKyMY!06;3J0DWe!gm5~4?Wt*V;Huf0Ne_Um26YPJRTkMY92ywM zKjp^F44rbke`KnU)u|dO$I3C<;@FP?hR6`VKQf$x#DZxf(eFbMzrk`;>TvO32faj1 zrBfn!0^Kq^BfW(72p1CpLvyy`v45AGYqCu-)K;KsRyl8M(jaP9@STtQA)7jyVMG^# z!8^0aQ>`px!f3)0>VYE7?YzdydI;Vo35y9S<;$joVK*`kI)}K*NwEZQkBZJ4-m_Vs z3Uy7e7TjzbHqxvof%|ISRY3Wo3SO?Efx>~w3@EB9u9Hq^ANgtamURx&a+I$t9;VztpKu; zWTZ(nlXzN#!p=z5S(*|m#LHmM?L;?&i`I|UygZK`#jmE!DNEl?chq>`Lck`oLVH6I zdI&U?;FkXR5E+_c^QoksE9j^^>=A<9r1)j1Q_?(}IN9v}2GBtOE{GOof>ahM?Im>W z>;?t9dsG!7YWGDBWrKs4)lXNzZbY7eH6o|0S`VbUZ~G(=w79BdE!QS$(*atenHC>0 zj>O|N1X2+k?HZ0sQOl_Wxop;15vmAMy3YH$F zfqIE-Wu+IPB_ib1&&%oQg~HRvVbzZ*Rw=NNL+++Q$VV9)R4F~MMjbkeS}ZC-ZYS*q zKwvnI8LZi-joj=$PA}>11zEjpaxcAp3lc9F5bh2;YYdE1n3{v!-|3d%oTEQb3-*w- z8thqu?orsqg77V%ObvBH_cmg;v3+XsmWo$NrTSrVfg|IBQ)q^JAjYv_H}O#bj~e+2 zKs0fRHl{3SU~?=-H+VLMfj>@kBrfFXR%{P3X>eK5phZS9;(cn;%UUjSW;6g z$JAmKj}vN>2kO;0OTDwvg^J-(@0D+>GofIaQ96fHT`_ZquO7qhF*2ZXgjrRjWKu&a zv^>r0A~L8-wYSX@5cg)o)wlO0;6Y*(@2)}j9PN+?|8Svn3>)_vUetn@)4?|(tPM1+ zhsV4XJl{dvn4M~8>PWlADm$~(L`B|lqe5l2%(6QZoby$X!h`@`z>T=N&Q@u)`EoU3 zJvc=D6XW5N)7-yxoP2B0o2N8?N+|EIO=N~``yCMalE9K%ZUx&YKVPLAy?%TWk($}Vc;wp+6B(R>0rYxcCvMl*dZJc<{z0^|M`q5=bmAqY*z&q78miz z2j#f&x{8m`P*#p=tM(23v`W~49KdN%lsSpeDsQNnAaI+8D}g)G%A~h>Ra5r4Hz)Bs%u)l=T$qjMf4 z2Irx???)idJ|`>m=q=@qBKT$yEdCHn5e=a6d8c;FBI1MA_EZa`#>tdJJ?A7{L60J= zle&uAm8yyul892L0t&+|w3v#KYd#YC7_{O%T2mbPr&&ySM%PMatAx)1jpVAzn47~H z1^`{sGeIU}Z%Dw8$u@RB-FS@L>=__k%`(Y94dM56{BuV-d!ex&3rcAx%U{EK&F#pS zC;ijo?VFMO>Q$fiB6n^@IvK)3=lvaTS88fR+mU1ngNwFQP2IwC4N*BaOM!1v%{!Z4+}pSg9$@_8x*lJK69Jmp8i21%c%z5bNmt*@>|v*R=C0}M^8J>PJGDu!*9ROY{*Aoa z^)6)i&|`UPox!JL`Mj4*7|ni3A4H4Axt|zV9Ss?5#Y_!XThm|LQ4z~kS^1Mu)B1*j z+x1k$9RK)uy9}wi1#V1V{0xWdTwUXF!?--rVqVOGxG{2BF5wjcE;+C}89@W=X-8N0hc}pX0n^dx& zWJRIMJWHpv11g*BZ*xw*0nuk{t__n_=BJGKFWyUb+WI@bH92cZXGgh-@fleDqYEgD zBD+J4;J&zbW69Hp>3Se8TwgrKF(T@tkqIBo1!%|IC{8^R$9&^iR+U=jw3ETMjG@F#K4ij>ouB>-80i5@aIfWgNd-A zzHspRdKg_Dzq|UnsiWISsLz_Zcg-Xwe&`_p70B)PRt6W*>B;%4{uRMC+S%LZ1P4)< zxEc)UfHer9VOv&A;;~zcW=G7sZYlM_T|8l5DPs`&IK&LMU{D%7OT>qB7IxzVE10ZV z1UpiL9=am#E;1Pd1E{ArYrXS8m_h<|bNvJ9FXej7dLxyYL^G~j*^;3ip5`&)411Tp zNGU>ji12mOOQI~zW-|zCPxW`!H9Xkol;ysAGWXt*;W=@DjQ>PHh*J@azl-DVx@PpQ z9kyd;eX>xK8}E{-o%=3nb@K*I4?E~BS-XPitb&}bc@icnzEyZN!2EduDzn7~((6rA zo-`m|3M%7PI&?`f&j73mpgbO(tBNYa7R@x7mF6?n5=KfbGof-PfFp5uNW7 z=w8OWP=>~t52v03eP+)+;eKV6>7(eDTHP%CSvPl=_QU=C{ii0i!|+)uYbG^6;=Q$m zK>fGDvzp`3Zw#*+u8!||8RYOOrxPcC@xY6Y)b!UG5^wx%+}Ete*5e$fhhp4rH=&#L zTtR!t2S?0m+BRuz3!Zegfk)|rWu|AJpyh!zFPjDYaMLEu{Q=Zef%oDk0$7pT- zn4h9;;!yP0I`lPGeK&=mW4g7{Tjz@{1)gE5p|6V6t3$CUK;)c2t5#Xaj70{4&ehbZWuXAI{eB zkp5tX_xlA-1RjE+G8`q4@KlG%A)hiN7VsF4tiU5dG?`0`j>HE3r5HTAY7%pZE*UqZ z3L(?+Ghvg}clpQ)e>ul<)Q{T@IAIOKJFP-xM_RKnN?c2@?eRN3Dxlz!B8S!iKd|SV z*lzHQ5Kz}AmMdLp(2UpfKr_xRIyLU~E+HV zQ6z_JcJbm(zjtpn5`Tb$us(jU?Wz?1Y*bsd!FImb(An2U+k5su$hH3FdsviZHx&#i z__tDP@APOec#u|M(EoRu)O9&kv^%uk8?TWuU(rCo^J<%$%=oiWAT~YRX<#}bO1aG~ z-(r(q2=Gpen4X6;0j7)&?#{Nr{`}TT(#TWl(X>7r_k`A3RCOQgvDJiC{rk7qaJ2X- z>O&VAgdZVAEkp{dAfkc-MgQI2w>GzNBnf`^ujqj`!hjcor3vrF26yOPnxbWnMN&&r z_Bb>Opb0d|wm<+z1EP68|M#nWRb{;zASKzJ-9gM)BvF<1%F4=jBC&dp@I-T>U75bt zRDrx?2WTJ>BE=~MQn^y_V2_DU?M%%fP$5WHkDQoZLiW#1!E03 z4S=~|TX#XXtDwD`vHf#oc}gxl==e~vm!aFdY@|@2G2&=^`XnXd=!leTUpl% zM}{K-JH@1$8qNXG7zVMTuM?rbsmH$YqRE5U0(9|4ebNFSZ_ry%c*@(^{FVeAj4hG; zeMbedE?^d==Q-RH7QI=yW3xomH#gO!E|J-0x`pY0#^$2aV{RSK`)vFur?GRC439T57B?| zyrG6*)x*?g=8r@0?i~JP1H{@Ohwh5oK&FI@3bK8_9o!D;nL77`d^u5P=vKkUh{0<+ z>AfD<5V#Nm&UPCJVVGP9=yI|R6cwxI04&!t?5lcV9N-YMVXiO{8T{B)=wUqq`yk{y z6a!fmBTS4a;RD&U*In=N!6~{+7&_8zzJ(G{AkK(U+#kjagaIBKFDrTQiQ0WQt((>6 z`x?yto)?x~O5q=k*4`Za&3#mSFHX5*SS4JqmDrmb3K42-2&oxa5kizi186^jEzQW} ztl_#D#7F7P#?03H-8PTC@{^rkziCBIyDI}oX4lQobK5A4@X5A9k3@~g&bBj4Ow?aU z(o|*fmtQ!S0H{0HmQaIn-!I{YOWGBANa(;YEr1ISrZ36X3>rJs=FlrUpIz2h;L%VP zBW39xwcvGh9DPX=wHsvJ?4zgfuFkgEFO**f=?gq~_$_VlJgpHr^V>(?8QDUXN2;e*Lg%r`s;bKvbxI= z#q^SAJXCMcfpiU<&rZ9Bxm(qj6xO^V?9eN(r?P+|^lu zScP%F^ZDELg4~80jMlxgS~vn(qbV6wgctr63FB(LlzVv85_!uYEikhNb5!*h4Hhu0 z3WkKwdI2{J$Ga=8YqX<~CbJPGp>kzH+tpjwCYo7OOfh)((i}9(VCh}OTNyBkP0GHn)~IOvj)d3|jW<-%!$rAFH|3-XG6&$4U^g9QjW;v%Wau#F z3a*Zgns4e!#d$!zk(A?0*K=;PiU+hbYcM@P5WWk+Q^EH_y0B0Wo#M^%t&nEkUlva= zZn)uCZVZtEPGyX4IOXL^Y5w5j=(JehS95bU{FK1ferF>l>_P^~w-Y%dd6;(Pi3+$i z3M;bVN3NeG@X*!%l4wM*ofR3j!3yB=Zj8@h{Snz>&BA=qIi2!Zam09FRC0CMK5 zP%=}7I0Vm!%o3VfN=qVk!1Xa6@g9L zgJ<^a3TB9}V;bBNoZn}4bDd(aG#hak@Or_LK)%o$>R%oXDYb6zo+Gto7Vq=M0~GgW zf7*EB2|xIssEt5@gtIKk)YUO0m`m?UZ~bNl;f#VuGNz5Ojo(f;8yLi}Ka^DfyZUdkq#fdoP%`rTr94QWPl7Z{rD2|pP-J%4 z0`Ster|{-?<5XZCn5mXc-NLX${MyH2UhHKw-`bxuO*CQA&4(#M@*(*Ypo-o-0LQG& zuUn{56eE?3BiAXEZP*T+$zU)DJD)r|DNfN^=?bb8U~Gc)0c6Ichav_UfWiS<9`0iI zKIya_3WPv{{MlaPZJ4%a>zkW9tghsRBg=5&U<=j(Dg*{}iqLJZ*UDumizmxVddLh!fF zi*Pq$i&|IBCVTj+-pEp{DcALEwfp#Kx`!TE+eNz2tGDH5F3WL+#U(tK646Q|oJsVC z42D9xr^m|R<4oM<5uN)R3@4j9QCrz9?-)&s~g3OYceHN6v!fnymw)3+F~ zOBWIh+xWOCFG<&;VF9}%nyLBh@ub51Fbr=QbOm0< z4eVfDav&$=I<&mh;NzGj+KmulICMf9jI9o~lNF-35cA?rJ1qGitFHudamsb#cgV?j zCml0g;i7gd$hTFsP+1^FY1P1P~m6A(%1rq=49F_WJZCE|X;9(iL$yjRTGBY05!d#NeqDELU z9IR?*ypdybWS?87DSzTt1lND7dpU^qyB8ykj!^f$2v`A^V^d*LNJ^3jB~;Q?I}PHZ zaOplO^<&MRSi+Wa?Uv!0t0Qmx(!$iX@#Hvg*D3pmjc;R6m zWt=I2)+!w4O@7 zcjF@NUYchMecHL=fE4gfDU<>GLN^GHvv>M-mpnO46uAn#s6sshBoJd+ew2JwY!wgO zVwJbrA`M4taS-`L`5Sz9RCHn4+=B2#O(w#a9&OXrnsR?96dn}6(9CV{qK6eDI(%$J zhyM}ir3&HJq%#|vrsl`E8NHOYkQzbm=Qx`o?o9Hjh7V`$(Y5Kj0<#y78tDp&!UZ5h zJtKU0%Ti283mXvgj!LsU|A5UodT6h_`0ahD*nL09-wbmKZvJK?1hHy39;GD z!q(9V)6S4BaERoti~sF=cY85Qore_Yg!e3w8iexQ+f3T+j~}l`5H%LE>`TY58~wA* z;mb6wKUi8|+q{6*8ME+udU|r|=udD@+L(F1DlI%Q1IyStqK;}X?|3EnCw0i7W>Ly- z%}5;%2#U(oU2na^lVc-qA3t*=IwyCZ|KP|q;8mR#n_|*37In%Hj43x+M~@j&$Xmne z=_QMBX`U+c*j1O8pb|0!-|W?uV>?`)orqCZL+(+%Cr|MYhf^AENO&RUr)gA_O+9xH zC$=`?(+d`AE zoq#JJFVG-I(A$M!__W*YA~TP;0a^D@#W_esUujOF7xNpol$LVZwAy8qXuhK&dc{!* z>HDieTL4Zy(#L%7dxt-Mj_{7G*y-6mD9ZI{pu+>@6>c!1@7V(^G2r zC4h%-l^(k9z(R6W^S?k%OGswn&_7*|!MS+i_ftXYOHFngOQLr>2r+9F!GBwI{4GQq-d5#Xb0r|{V!{yU z%l_rf#iT55m;Ep{cvtO)zu_&>-i0TowzrL9hP93Wm2wHS6a2jig|7=oP-j;*Hw&ND z$i|P4;C76XS3j+&V;pAK-Ry|Tlw>yxyQQVOYb0s(3$iJkY@Ju-)vEPrsa0V;6ScuD%`<*@nVk_c2YoDhp_ zra*PXZSEwGiZk>UupWaK)g{D@a*jngKu)vB2!!RYXki*w7{{NQK@`nNAln)Q5b_l8 zx6&u4;_x1*Pj*UV4eTNMWTiybpj!3G_q7fMW+9n@I|)%Jx?R1Yq;LW&qI(x9zQ4MiOmQ_kNO(&h72ji~S0rO7ykBw=K28r?OBz?oC;`@W^O)U%19T%5cN{QM z?4zfi_wIG=o45J?i+U-0rMF;Uc#~KWr|nRzSZG(iW!QhyScbZ~aI_xvEC z^t4l_A6`G*c@rSTH)x=4IEd4?$xtnalBz)%?u3XX_=nCY+iU=fM`K<_#U-dPS2&YF zc74lw(d~i4PNziOaPYFFc$=MD5I9xB90NrtD+y`D4}!IHlphIiE$}4fo*wKU9DaA; zypUBTi`WHOk`p?Y;`XoKy9oD%>jVR&iY<_DWG$ck6eUJXX%V>fPtnY78%` zB=c2JFw?6`fNe+?5+W{tB5otlO7r6JB3c3p{I%vYL}n17&raHs|1q!$dQlTO)c`_Y=0qw%p0y2%#r{lK0{B2#jc~-(tkDMG&$4 ziRK|C6 zE-Kx-IZY|itY`t#tUwaYjh$X;u*@(%$Q9BEM)AbF;Jp7SH0O5ueWYeM4@htADUO1G zEmQ~@re8RcKm+**(ktVZAV)^23IRGiD!`keYa911nt8z$!P1rz1I|`ZGjp@C`T(Ss zdU0soY>L(GyjZN43ov|#b9&m7gD$#oc~E!T{AUkS!9dm=tRGmU@U zwxR;?FQGNvGiO&{tXDMRC(*f)jmv*MB7H_Te4-NSgf2LE!*_j;dGX~>DhE$oQwRXY z5f~YDMf8vX!ODQcEl4m^{stOKVD2m@r0R!Gr!Gms?R?q5A(20qZnz@aCk+!=z4RJ)u6rNYH^baR?z~x$XiZ;*Bvd&yH!_EB;@6-dcC!G zcpMA^L-B3wknsA+n=nd`j0T)g)>Q?APppgo*w-cIon(UK?CG&`UdFMXtVDpEvo8{m zBSccLVfcX$8bkAn*gyugu@rOVlxLr~nUTlBzyoU?+t7CbLu8@eHmTrH0Hi|OP@HVJw295k z=M}lm^*wOx0HJy+t4)nYR4Q6Mr%JH`%+?GLagm%v{GhGTkMC|qUKZg}WbS@Z4Xy@O ztuqpLThn9kqnX}dTF$NpbhtrPv4jHk@~7D*FAFSvxeG1dORK0o8#XGzV%LUYJF(?J^pWdTeIeFRA2bU#?yuiq;G4S=K{uz$` zx?F})XGk_JHHca3ZM;y2|5M#ihgzZY^N)MgRLgISu)Mck&7p$xv`Xfh!GFOBZ|(^M@L$iKjMl7o1kg~q=3P{ z(a*Mr`de2`+VIx~tp|KoBo97iApyGC@&>zlsoYL*jOR#YArKYD7L!PaNg1f@`Y$Tc zpDz`@f7$Vr#=H|bwK0F~DUN}flG>Ca|9KRpLWaM~UFMAmOK0RXz;?htv<~<|%3pK6 z-rWrANtjsb!)SsHsuB1$)g!mCIy0_)V2H@@@YyGsw)DeB^&!*j*J=7ZFi8340kXA` z;_?8K%7}FH*eB}4NbOsA#zmPb&Rq%`Ow1uyNGBx(j)YLc6WRtMw0&T;!K>qMk0C$B zCcS}7E3F#v5a&139;l-Tt+%WtFk|e0yZ8Jci7%Q8^77h+R9c^AYFk3kWd|^%!qUPX$sF+mU=W=&KPVotfWC%0w^AQjn_8 zt5@ljGjgv4>UH@I5#N4Szap?%cDZ;u*6k)=jGoTPB{xsKAJ&96{Rz`_`V(zQ zn#kYJtS1974fd`*B&(Qc*)2uxk$kI&1;kW0$E~_sLUC|{D%?b&it0Vp&W+5AI;&|YO3cC3Y2&#tRwO|dZn9H&|# zo7Q1*i(RnnyG@-4=n8W79P)`=PODTQ!AnedeYKt%whR-X2p^cCi@fo>YX5>VJNURt z>G>gukYr5TExLqbuix}GzQy-7j=5H_hldB>Z|BC7rSg&R-L|yXPu}!?2WOuTC=0h6 z(=Z>68Y_G241*0K-Io4qe@wULJoYBE$FN*;`hEHF>CPL22yGl>#@d)ctfz0R{qc8U zI<-pU9;ljO)RJ7?!rr_+k03tQo=zO)|&-ReTqlB>V)+FIjAjm@jq} zbNXPv*mU2CxIsD0jnb_1Rqtcb1|uYALg^@gJK%e~sK;DXBRet0?n$|!Oyk^Fo{EFF zEvMoTUXx0O8tMjRvay0?Ng?fvx%tiTovYglb3ie-^4*bUDN7Zp`dJZy7Nrz0u{tkd zX>$vxMid&$kX7kv296mC%p#ZI=vzY_4cc1NVr(}q!(OS}1_Olnk^lx%$xTA5@`i+D z_nnN)WbRGojA3m%576e=as5#VMbHrpEhj9FFrU_E(SfJ2Nq|U+4cPUFBXTLIja63| z{6;s?pA$JrnjfK1E2VyN$KoArpidlGb<|{jZiediMu@kbmkU9$EfqM`%X-Q{QnO__ zsUFWSFD;Uz`lWVhXoKh!3L7>Y)al~m~*;A9UN&*r%)vmPl;GBSy=WX4nvuvPXQ-?D7i3K3u zhwNeQh@N5z9BH#63g8#zto_ZwFCd%Xl8V%^dS6N&$W0T}@ZZN6Gum_YZ&{Hx4?9A$ z1Ow%6wdNhXi>H5yoo3X1bU%hQY4f-Ut_s;N!>wpR=|D<4p6jX(x7AW1Tr)Q;9m0c^ z5$*}(By48mu1tKTe$0wiX!_wQTR z5c6vYZ;ydu16U%erfyfzXcVllCP=ot~sGcO69X`8H zJct*mth6q=qsk?SFx8_L5aV<)CU8SKyE8)N&48ny&g+}%sGd&iCN5;3DfdAypH#$2 zq+xdxL`z=cbtdgKMDpB&quR`hSKDr7eXt6`J)fj}j2GKDblDH@mUmxm4 zpkvRIE@uU}fO9inVdO6-Q{2Njm$%70US@!Y3gWFo^9JTEw719ns>f-E z7`x>>+6rbQ$0bfQzQXNA>x~RQt<1?8_I$ox8o-bcjbB!{$)ZfMDRL7UhzucU<>tK; zbreG#aFBN-vzwdlzxHW}X5!1UZLF#0`7(CKEzfng?Z|^I$kIaar-H``m@4&@>~_ZM zW;MTohpQ8mb;IogWxeT(NmG#_Lr?YEP(vuIM9#&HVYpm_YSyd`JLhmwvR1R>^-VcJ zPe{1=*ufelK^Z%($17K5RrzPb2P`}@N&$jy;SK|yV^}@!5fhtD4!OzutMT zXZ25z2!-l>KtjPu-K^@FSv~+D&ogmKA}0vX8w);Nv>DE5-ROmkPXqQU9lUyys%gI$ zc}5v#8-^h>X6Z5g$O+kck=sQti(k)XFCgH1_D>HCt2Pu*pw?cIuMe^~zojP+vZB_r zDstp$^WBMKPLlj(V7c1BZlX7X*I)Sss>$|zih!9>m>gg$Jo|p}*~D;oGqATcIii4< z1p6cIL)6R2ff%82$UzxJm&moHg!j70O_1WnMFQ~K;w==1=Ia+^dXcf*JOJ$>J0Ns$ zfAza-`=>Y3==Ib7pWlqwbI+jHWIe_NGb8K>zCY&cWmKy>GZf7modsmS{?lX~Mgr4%>fVCH8GMhHm@B%ly~fkLf_33uoSffav<3Qlb@3 zcxV}~RI%1G)1#r9h%0sLJmsJ*+^hFbTKXK2hsONPz;E*9D87)vM0Y^U83Fq?;zdL? z7Nk&_ccfVIu}Kzu?3$79(7SeeNBjg8K5H_MSB7H_4>PQ%x~=#N4w=E4BOM?ywfZmV z@Lyf_Ecz1H<%*`Iwk8m*GDeR~*))%Rk9g|@YG(NFd!y?AB3#T#wdpV`un=PvRa7SM zt|#Q*dkcrLjP^eq$1Uz+12wm|$s3tuz=#bKs)CcAaTdJP4yeF+EmQ3pRzUEcJf(Tvu z#oMY{WZe6FMl(e{hU%)c& z5rlk4M=n@Dq2)*|cLzWTK}q2O1)BJqXATz;2OQ>BRJGdR<8<9XZFs*x&NjO@2q(I+ zhc{~~UT+#FwQmW+)>f@2w8}9733bfLt@Os<+ARl%7^DE^eOP?Ip3GfKdE{(yFEJuw z^9jVT({HT0hpHS*U+Y`^dvjF2rKgq@61+;&*+pui*@6QHwGFw_KFp{A@QSo^4i2uu ztn=csDDFRM5%+OcL{hyeRR+EfHk7OJfIDbPN6N}(#h1LEy_>(4iI8&w zTC3}u8pm1EH>>%)FtX(xuhhNAIx#`UryV|8EkhtkFOyg}JiB^>yO+zV>J>qxVla?8 zvy>JsOvwTmfcZ)3{NI&E7+#bsP^Nid7Cgm=XZzvG2)S}hfHR5_3?sZ8|vS`X!Ii_mUK1UpH%`O-Rf3a~pq z=HIN#@?|}R1YA^I5@QGQ8mrD#wfg&KU(e?qgBy2B+i-Me{@90~)W2Vql-D96;TM?v z*|fm;SD7DR;Z3<3Uz>TAS0!Kv8GdG~8tPBiSF_TLSNZmz^Eq`Q$BdN}3c=tFxf)>G z?9eaA=F9HiKkLN18)b`?MK^9rpnm zroHzdL85RdJ2&x0&4EiODe)_{80hZs6@0a&H)eHHAw%!{W5@Pp2Wp5BmN*xHG_q=~ zT0)@1b@5mUb{dWut`MavPT%Oj1u4!joYrIO>#c9Lz0^z<{tGXwswV?vVp^_e<7=yu z3Al--Pman@;1OAxF|Kzzoj%=q`x_9V`C`@VzJB9BpOd;8&S!?juWi(ZgwG5c^ER$m ziI@xB!U$JPuO_9T-Lv}tRzvxo=@wE!nhN|FJm74_1a19dIS0mWc(t6b7bzz41@1%; z;UxwZ zN0x%_w=e?$Wd*kZ2YW^#i@*>QEyo8O1|X|*L{<$NP~k?jLz)@WATjpM7VtZry6a>Fy91+IbiJYAjUk!?6~4~KsRMrz_&Jm+QsK!8v4!!bfLA zx|@DLw=19V!VNb2VvuNRxb{8^iy)j`R1UF_+kq0DK7rv(DzHKG_0_duaohp1t3ml7 z_RW7Hb+!NEC!?^?4k??KJ!NzL!$ilT?nD;O21PV4uHofgI#N>zra}NySCAw~Eq`FGzbRz}*~0)%QEPCt`yNE5-`;`^#JyQYl-kBS#_gRxRw z3`*QJb6_&f`%_~&HpmGGln)aMDoTjYkmrRD`4Pd*cRs1m$ygj2tsT1|i^38$j41sY z>2m6&Jf>7|cUG@CF4w?O1BOkia!M}T1AmQESix>_3M&aMyLB@ZW0Ps(*KpJgKsp$D ztt=_0V1|U{+6C_fmqOcPlsbS7Ka@bjLShmCL<-W~L59ZltlJxCno(X3(HnJ_-5mxH zgk{hP?fM!G`wp9-lBfA!fNq3R1jT`)f4V$OU;cb{vj6Sy?ELg#??`VexD*;}8S1LM^W(xe`@oK$bKA~Q``#=5S&;DP&@H)U2$<6YIHG2p@fjT%D zE&}en6=aVXx;qZg^rog-GbzEQPhHc{=>v4$fyc&817%4%d+a(x?lMc8_j_e?E!)s8 z_vUm`^{qSNMe)m%Uw3|a`fJZG7L(X*4+wjY*!$Zx6HG+dMR7gQu3z~1IgZM^i%N1j zLg+$6u?q(OjWoVq6zj!*`nbw%cg<~rAG%a2lpNO074e`y;}Qbq+9mEdb7{neZ$V}) zzn%mm!A+JxfRn9%{0V9y(QsBxOddC^pG>$rScqny9Km~LRjQ~(9Ra8jkyX(UKWVj~ zq_8!&6n(+-tC8FBumv5FggOb<#Xv`>-?YDA4B!kzT=PTMOKy*v>FM>nS!G3rYe1d{Rc$DyxOY*c62<-%h4&%^J^*Nk2Ot4IDu^~36nh#}C_d>_bZm}? z$%&E#!pOmzftgh!4e-OndOfIy#^Z142~GZc@`A_CYa1VMI(FybW?9UF_XbbRY;Dp9 z#d{3sV7vd8B!Fr6xy~8Aq7Em_=^H}Z8AMHm=-z+_*`xvvDe-FhFH)};PMVt1v7-Bb zI^9E`n}3Cy?c$UqFss`UxK~EZlDTgZl&h zlXgV3GlQ(q`2S?HjZ8LP$ss7LK_UA|1VlVSTUP@gs=1Xq$^z@t$codwX+MW@Px{3$ z&4kc=1O6kNYs-3Z+jzD4qG!QuAY~5sNgsyocaPEAm!Kqs%=I1IF3Sb)-G-@$ z2O#@iPpTmFz?<)Ev388H+p(MSZ3SxfZfAMXNq1FmKp1yzzsb)znlH{IJ<=w^X%}<3 z9GQbz76CpHP^lxba$J1=F8^E=cjdTj>5(%=tTwM zO};m|Y_@dWLLS(PR(TH~vfoU&Z%ed!Q8F8SG6DXk_aiM8FS-(V8Yy$oqV%Lb^8cg_ZPtr4=v* z-H1CHmi#w;4`!fmX4_FD0^#=@|@ z)K}q-E#}Q?>w6m-Zq*}4iZ`#yK-_A?*4nC3NvDAz2B8b>hFjmCKUuyZdM(A0ikx!d z$$eN>oXCi(D?qFl)v_D!0YAG6+ zPj|fFp#A-|&YCW@>Lt@}`o3mAAX^%GA>tJ^j6?(3EVa!QGh{}DCsx+sy9%#FOv_y6 zj#$Hmn2m}8a3}2cdarjLKYbH{m|y%zHQF8wVRNYKDQ&{U&KD}zSWgD1LHEnkl=coj ztij`+eV%VGe$NbM^IAB>7#u7ihYJ0506`85vxr0<75SvAy<@$5-e+6Kd+)^yb`sG^ zPtssH4+x`x(@_|_q1??p#l9j;uU(<@;Y`@H`KM zp%-maN{i8nVU$9CJN;Z1l=jR3k5r+X865glB|{8K*tIW5t6o0ZVq7?G7k531h3CW< z@Ld#jbT5rVSfS7W480D(t~tAsfc;1r7u`R|6pzWCSwiAOl#(8S^gLu}Ls(wtq}&t? z83{@O!Hor}001XLr6v=rEGM>+AD;|M8Gj{=sP_$XY5qXr#>N_4rJYNQJpmKsX>cPI zo_n4MUF4CX?Fv)(s^=|FK=cCTv7R|jeG8ohD`Wz8WlK&4c|Pkgh9zg|PuojN_VCw+IJfo=$l(<9DHI`@j3eCviZLX zQkGa%4n9UPt%^P!c45U#{y{+f83>`BNYCo;o5D|ZdW4@LUfR=s?Jxnf5BEeKb^(Hp ze^Qy_{LLeG*hJpQ&K*?0gx*O_bT5X|3`%p2%@xMq{00x0K5rv|wmOHfphBXoVbEk$ zQ2m4kw&3TwQf2C_BXTA~z#ewFqvw-Po2G#vFs>WQGT0=P7@u)jj+Otiz_=z%lr*Uf z2MqTUs*g7!J!r|H^z@ZL8M(1W>@$Otk_3_%0v?2YV`vh@NB08PV+#udVtB^N`vj{B zM!V>(wpm{QQ^bMglX|>zkSv0YkkS|`GwK&3;G*UC9{j;ppuqice`1~3sPEQ}#1_7U zMj#m>i5P91al`uU$Djh=Q06nN+S%qJitF7ie)-k^^Q`)LMIz}Fcimxy_8G!B4#O3U zt07lS_pJnb6DYgL7$bp3uqlxPyUvSYo0pS`-^+Hg**_`QuaR6KHG)Bey#tVJTa>O_ zwr$(CZQHhX*|zPfUADc;wr$(SF5NnPUq|25FFHCSXGTUw##)&fbB#H2{A2!KimBnO z17FP?LId~Rcuw0#l_SGMWwXGm7)2e_3v(SQ%qI=EqFT)-q*jk_6lw6NFQ$? zOkBpCzx}pobBEjUI8+4Wm?_f8(rPo1G9;-N89^G^RHoE~-L<#M8*Mg7uGB(=x?nRyOr9345_k0gBiQh|0ZW2R3k*Y9MG; zGNBwOLs11J*?_Xb5+?Y7*Z**Ff4sH38DaP^D)pzQQ|{T@ znC%E{9t{%0#>Mw<`Tuk*wAj|tBtqBz?=bLYZ0jFW`TGuEQ2F*68{5=>WAx=$l%C)-mVixa)jnWPhW3|D1+Xf;&SYd-D!U-R_YYk5V zf{b-{{6lS|dFKI%&bU;XDGBQ_Z>T^`g0q0xWJ{(d!IlYv5&==k1*7TVv+$E6-z~Bi z7Ghr{^S;Pp0#^iR)>)bpAq~OoMwSvn&4$$-F_zEhiqV!aUhT2Qqil&$rq6gmo4#QR z;;~GVK6SE1Op!^#kZv`vc(%Qd<;59LR;hlHRtIH}b^^^@pq}X8(LQ&rEZYD(sVeF+ z#LI-f5&_Mx@3AF-x}lT(BUQp0*W6`czT*~f+)3%t0ZVHhVLZIN+|5;VTfpT_(sY(k ziNQ2n@zhYq`GRI^yNnks5#cA*FK>LNx#EEqX+@lCBLTmw($T>jh`@kkxqIKB^GC7XnyA%asy zcgbbiz5%rJJLCwdXz(1x7LxIDT{)=ufR(_b72lGK_Xe&3)+2tS8Z^_&_cC(b%XcGw z0j33>w3uqvde(Yv%idJ8V@!>KdU(2ioxJJKd~tj$3xN?sdK3uwjp+u>JsGg1_Wk!)6X@A~XfItV`jb_P=AiHt-neV^0fD-+&_q^%M;P$za zldXentX{)8yaA?t@w`Pt>~{UbQ>(Y5p?>XhRyTG;xRqT=TH1iEy9=k?$~5Hi zJt%8ifmU}9v*d^KX{POf<#rzX_h@>t#dCoYr&9yK(EM75*G@3IugV+F_M(JcF>bL% ztaTvK+D1D`@B8<5RN??!H1p-uxNs1u-R9fGU5ZwO|M+UhFMB%OpHg~@UK}{-2iziy zTWQx@+u{NkM;9>F=C-}`+ltN_GNx#sd0Rlb?Vp9Sbmea&?UgRhGO23s5ayld_iLL_ zdGN)B4_7Xr2s365D7r6b-A8t&Ab{lecOGgs_xm<~<#-Olb|@4g^}qnN3Q?=mzx>4T zIQHZ8M3L7|Bkb(Fue^v+?wQDK*!3lKKqCazc@)n45kwF=8A~?-RD%H!&FJk<)DPH+YukiXt6=Ke&{YWug?cH&?$6csDEs#a4lHx5 zg6q<*9h8D>5Q#e5p$Z@ZvHyDRc9#S46VCc}q5Ra{&H zj8+4q8m>*JTQhe6GF*j8oD^)Ykz7>5Ji+aG1ee8D(`3DwHRJV}CjJC`sNdjVj&PEu zV|EehBZ&zC^}=!7>F@wMo?clz^APT^*yvu-lOaKG6&n=J?Bu2VtNr+b3zXDA)j3;* zO7u!twyWsNJns4O$QxNBN;x1Lot9Dy3N=wnnD!MqW8}MIPT6g zdFzLiuub~I`XP30rRvlIX?48V1S+XM*WoN_A&Hr-qSuA0c7&_=0PjKB2R(W4HH~fa z7-*{AS?-!+YXx9di_yXCew=MYjk=#Kw~4rs-MsWniL8QC7WVjlGo-GwNUoI>6;djy z^}?}B`YaTCwYTkheVj+fOlSI*uswRCg_<6x~dNaRB9mro#xsh=8JeMO6~Zs=zm1ONWm0saUBEy3%jK)`@G#eFTdLWsBTj zXmvaSFVU^sL80!_%aR!PbUC#nvmj2k=niC=Ru;O*V;Tm-kWGD-lH%7F*Z%S%Pzmjw zw&=KIYl{4{$Wg&<4M&#yx?z1UD$6xw z4ex--$tBQsHT;b`XjQdaD^aQiTSI0`te~O(wl7Y(BqBYBh;<-jTclj35mG;qKhc;b zEsqLdb0v}lFO-vR%y$4#?59I;y|b*4!rKymyf#He#G0&w{4 zNfaMZZ3HS3B49!Ryvf3PyAfj7M9bjE0KRuQkv}wm6a=9qwnOb5 z{fhJDJK+wcq+!%XM?Vfi=`69!P#V;DLMysU?I|8I7wOTQ4)u`089XHFKgVhGE0pLP3B|ZZIqI%%x7E8|-mJHb@Bs&rhVG%OV#gjTHUZyJkfsv*pnz@dTF04T;~qj@pWyaZh6JlrH~59riQ3`Xy~ zQ~^6DXB4T^;VxQ4wBYGWCbI&BA%Ka@&Azw8``R`%j#KSV{iU_ELf&R(qM$;Z`3qP| zH=9s9a}BBz-}9aUXQqR2Dezz#!U1v}NgMHK{c093)U_!m`{ zO2|&S0E;n=;P+LbN%Y)R1ZM9<F^r#I2CPPdY>`-T=2z;RW;X!^G{*;{ z;&1e_{g^&Q&6t1}`iV$dg)DHW=$Oww{%n!MQKS?VqNYTqrwlgioZHY5JnBjqEj zVr>FTmy0tCtobDlgDFRV@woxk5~(EoM=_!Oe*9R(_|h;D;cfw1t-3`u&UZ8+jtPN? zIVLvG6dMq=*(7RsjTg&a*vVBMJ{Nz#(g@=9RSNshUGlabAF~Jyn0V7ykASL> zB13n=kepo3Ujk_M*_nxh$4Hh4@3@j{1gI;_TdhL^DL4_ zLzX@KpJhuv)_UAh2q=-}X3@P$;CAi)J^b$$0B=4Z7A|OScy2z)B0Kz7zEyf^jpE># zC_XkLlXXVp1ID1Rd1pgF2wUSv=9PrMNFVdc#Tg{M{eJJKQRdymqvDUo+n4V2rzp=>Ok<2{MMVB@$7E@r)r3%VzJofZ+je+1%JPA|3@RXnWRppCjks^*ip z{Z06g?n0}D1<_i?bse1idhIP4@@fW>9cw7zgv}yJ+%PHY!&>Y7DD|#JK@eHy6OV^c zY6-}jP@-QuQAltX<&x5N1Z-Yb)TCN6FEk>b=>hs9Le00|A0C_Fh)gDJ(OYRT)YP;h zLI^rHS39Wgjzm{TCvW>2fW-v#C0cU0h;-MNDf1BiqhEZ+Uj`t??xG+|$p5bETt5*! zy>&Z%^;K^=)0*->`868N*r0JcQ|*9EL7QA#(Kr;S?vSXemXVxBb(NTV>iMU2pxnYx zK4ZZom!dm>GO1|$tT_gx*g0%)AsijJ zLTFOY8WRf7C@>b7KBJF8Ut8Wg59OpoK!C#>;9a{n+RD-Poe;I+H(iF*iZe^{2{Zv# zs79`p7R^$ecz%7J%3u4tKepM2sMG!Qag3<5C(wq$D4W_Y>BZDp?3-Re(v5Y~|Zk3j> zIfGX12`On=R_&yJD9)(e;V8gK`IeByh4(A%3b{!(=ASiE)7Mg6qOjajU({$Lh`=w< zZbc%2he#jC3vv^j^hpj^hl=2ULr?N&JoTuJ_po%Bw@eyN4)ql;WXZow-f=;YN#UJ6 z>Z-3OsM>m}1`tjixG*YS)z|(eh|%=O&(G|eL>?cIM-B0H=~d?5n8&W{Imi80L4P=b zxAA`kq~qWhMmD|sKClWug2;AT5(2;1c3B(Rr;zlD(K13$Q`eFCP~qJv8S~MB4*B>z zLvqZD^P$!LrA1VOsT8%Fk3mDS0*W6X=x_&;jc=WQPvQ46E{kEf(rI|^#{h8mF@rH< z3gy2g7>-}RLTY0w2*Qw??~0Ihk9mEBD_U7ix^9ezXqpAbIJA-j5EDuUR@j`WAS=TG zO8`jsVJ$Ri>u*S>GKw;hi(h@~w0LY!Dt|6a?c`u{@4&P4dsII9R@dR^KwVeffVZ4X zme}lL=*9HufOEUPT?*v_9hC~s7fGoz)N)Bx!RPY?b?Ptt%)%fmJ$HN}T77=~1(L;; zjd;tiOiFIBk(MX@|H38w!@Ly!gKygm3IG6s2mk>8pM?h}OYXRjbR!ZLlG9 zU8p0BiE43d<9SlRyv`EI)(d0;$T~s+(bFojH6oQKs<3^%_y{MUm~xMyRnw(PPOoIK z56n3|?;nF!x)3iJ4)AsQJ?*1wQReb`J(t=I%Q3he#=?sECB)Q4GgH3p%1+ubaFYI1S`sDJ&UFzLypB_AG16{LTHjFGqifzhxA2I0I=5SG>N^)P;rN^ zGp;Ux@wcE~mt|Ql;KuA8I&b<`UxXR0W6SaT>8-SfO}AOMafKI9V9hn;@`vZxjJvA1 zG*e>Wt&u&wG4YDEAqFH18gy-F&NX-FHXA)4eHxTgF-2J9Q*3woJiK3DlWveL*f$}q zQk=(XPi%Y4N_BllI|Nzg-D0;kz=b!Q70rlFq1xa;2nD5SNdQ=AVhT2+9#j~{Gy&xC zdAjUJDNKXCR;LdEa9n*`W#EC?iVB;1*2tMdruNq`8E-bc_7l;8%;dnB6lTIjNPwFK1xIEi>GIHSDJToJ!@9eO6tQ>38Psdx#Wc+D?>oFA|xN^taE|gNyo|P3QPP6g;PW5mg*H;2NHM!sbM1*13 z#e>sje~5|lVthi<%t`J=urHaNdV&`dY!ZM1AdV(ORsi-k`CN=Rr{K9P&v;D~c&8sDvpH8ur#x1P+x^QO#l?Y# zhF(5yQ$hUlWsPk6X+}D-S7@#BjtE)KA{|Y_4m)5l)>H0#QW^UnmO_a84$)-J3gtLJ zZ|JmcSW+);gbhcYvi$=uZdY6#fwNbwb==^>5?*!bdnPnFR`QV({^roBe&ULxvLWf4 zGC&$=S4zP+8z;h+EirKVI8WEi)rbRPE^4|Nx~ZDEctj~puq`L2vWh=GdCFfWaQB4G zf61TYOJ{R7c(w?{hd)IRcgljuA7aTu#`Ma*%(`a%#;T@R;05Y2yFMCf^2(QpBdIpm z0QFBzYZauVZ&*#H3Y&=3Z{&hj>$;Y=zK4$~vfhC~lZ53<+JhTv2Oao*QgRfI5}j3D zB&Es!5<0Qxe4aM@{C!<0wFXTr@>RfzIjym|?o_B&o9L#qoYf`^F z9!{F2yEw#kwO#IrSTZ2Y;^o5MKmB9m&>40qW<}1#@ zaF>vLA0uPPm+t(ygGR-^?!If^J7!yOq_I8Sj+?sTjH|W5BXZj|< z8ThqdMr>RmQW{I#M*&?vRoP)9TWJH@Om~Ti{gYZy{lr;LY0bp@`u?&%al&noY~KN+ zl}GkNqslvQeQu5y4!Iqe`Ba_);+oJQ@H)ZB}e9Iwk2U!u_W+ z@K$U1{!k`Xu8Yh!J%_!wk&}z=yX+z52u^_vSk*O)?`e3ZbEDxERI*dkrig!KV{7Os zce8EMj>98WcTH#WG3{mOSl=-6^!UORQn$Tg9GV!!n9G}U&ZJ<@F7l}&62UE=`U~)% z=nn<(&+Q*r&VM}l{~n<=w%AD}W`d@AsI4P7aobaM3kGt=GRNegRhOZ*nm+W~T; z5lW{bg9tboPM%%M)CsRfzVgCbyqC@hJczn^1O|``g1gk?=)8{-?qY_xsBsDVIuN~Y zZnjv*r_N+h+5m1eYyA)i%BM;sO{rNY?9O1?!de89E@a3>>Wc-$_5S<@Cd5OqjY5B7 z%e82UL8x`YOmW<}wW}3(wh~My4|~^yvmLP(M5ueapJe74$j6XY20h(+U8liFY$_tm z_JV;{=88v?j+sZXBt#25B{(lYRbk@8aIt@nx7B>{Rb!OaeV|t$!sSe%jMEkA+o$I9 z4_hqOo5+s>i&7<_l#z0gW3yFCBN%Kr3KJKBG>C(W22v$M#1NGT5fK~zxUtEUX0#mc z#ST=1Izw>E6Lgalv31wV)Rs_loM%neg&Ot6R5%+JZnH?2i=Y!}7U&#z?b@>&%7XjC zZf%t?ts69MuaKVt#=OcZujVQ`A?zfo1TE)IvK5@LK&mq)!21Hp79c{~NMH#pMAW@e zDByxiMKy^Mr_juGzZkzFj~|YlmWH7V4U4@0@!cVMiN9NIMw>v8igPN{@~=x63MPe< zHC13_z*sc8z5c+=o10~U7k_L;RIDi@Zla3<<8t*X)xSeS<7H9sK!%;Kh>C4(R_|Gr zNtc9l;RhBXi*_;#r0LvyAHi6uHfUGN9Cjb*wR7p-k^>%?h{HTUOaVP3k+nq1sK_;b z?+{q?VwG9SA=ORS0Z{MWm^twRZRlCCP4c83|8v`Jmlf6?kGbh_&k;e!g67Ce2&s9T zQxEwy&d|kt0OEt1D4vTgh8j{L6%>FS2FP6xemPVEVgHb2Ieb^2Y~s_G>^dHC8;V_P zYr065JvRAAoWuQmxW%iQr`%>I;Nw7=79>;!HU1X{1ZP5 zUb(FJ?%0F@i_KX}RVOBhEe2Tqz$a_?olQ?xM*Hw;KghKncw0`w#~Ve>0r2pxa`nie z`7o!K1w3uvs!vCFT)0P=e|^KPXIl@*wzT_xL1VFBD|JY>^<8RWgUQ>yOR5+{>612- z_i)<`d#dU*azp2){kV0oZlVY?n)C!65yL;L;=uOK((a1Pb&JGp@5aWz_PH6`EbRMS zqjTwf8a+hsq4VZ;$EK5;x4U!Qi@yDtzQMwC1N%<8(Gh}DO1_#2H=>jTUo_ogiZMwF zA&)WH0U^DQ1fp}mC=b`rIf?WwOLGTgQy7K*SFJ_}dGz@ZaWJ^FB$3?j_&@xM=tOZ= zs)d1nclJ*{_rmy%7vYsz*jf2^ZEC`Lt_RKof_8{g?K42gIa+`eJ|8;j|Q)6<@!D@Dc2=)jL;^`Mm zDM6H5acyx4YnhwnQR^P96?PvR5~0KIFf(y7yBRJ6XBIAZ<8E)X^x~v&{Q~}X{4B(# zkohNktNg@Jynn}NBM)a&eJ5u}OFMJ@ek2=7C&qJ#uAmxhr>bC@d#QGECzrk}*pL{2SEl{F1#qZY3|3He3~QM@Ey zNj3y8h*$U?iy<3YLS?KL{T)-9NSF>u_d$-yozxrp)4&M7268(Cwa~&AOck+Pj+RcUsGtp>P?#+>QLXewa*x|T>3GOcjNKSXYd0jAKem2}i%>#a$05l&I1wf6Rp^i#DhcCmsDy8?qISqs_TEtM^A@Nohp%>=D z^SVOS!==qVC5E1^EG&9Z7Geuo?)*aW!^k{dgx^k&fiGEHQHrJG+=iOoTR6Eg8^ZIt z5UWHZxlHes;{_haxec~l$gta!mJXSVA1nmR_3lU3zuVIo+WNymer?1zp4H7Oxk;PB zktCI|R-EcyKSl}7-CO&+5n5i0T6<zLTk=tEuC^rqlm` z+q>eicG}=b-1(%gyI0L|uB;;=Z*)uHwa*#(g>;g*ifV1f#G0g}h(be{4?zd0HK)7f zw`0d50YpNoD3_hlg{6vz1i^y!c@5TU{Cv8DS1&Ow+TcF!o6h6+h`;LSoHd5ZLZ0cj zT%XQZFJ)R1m8zN^44sR+wk#_0sqRrFnmHCu4SE|XxU#79sP;zlr2WWrKxm}Eau_k2 zN$n%zurQSrCx`-_YcbTwor@a(E)K(Yay%B}!DLd~-*L+|*^j!yqb%&&H7A{+LJJIj zEPndq`;}QcAPe{a`pJ_f*-uPsa!&-`_!bYziW)I*h;&!piQAw*wb|0jz1J9ZV$R>0 zD^pv$Et{>+FR!Q9m7~|w;rZR@ahL~qILnmyOR{b%NgZ92tTF#BlDdI^iE8rEc{;#w z6VSg&sVS2p`al!IC(}q9@b;cG&e-FR)WNf0D#$iQ!PV1@yq>In&Zj$D7hhH)gfB7LzZ^nC1y4uI1!(12P1+Qn%-;b?V6Vq;S$`5sI zd)7gkVY#9}YyrAsWa3tHOROjg5vxbOSB(nNDyxd2$cc}a$cLWC3P8Iw>_H!)vKtK| z$gLZ5Qxh0CA$P=2ItSlh-M|-WP+^u_%@EHM9$1w;K6t~C@EaoYd=)XEo_Wjj!qSn( z&tIQj&zF&1U72+!>Nk2iGOLG?pD#xjD>nD{fdRVQJ#Bvs(`Ke-#)hUqJoPuzzu5V+ zf9vb;F#C&C+VRjTN~~G#G5npi?|V86XlG%lbEo;KVHoE60Br%-kGas#!p>3WQEvlo z>(MOqI$ti(8<$_bwBxo;VXk*%!!Zm++Xpp8?*CD2 zhEIoU_QaP=9N+9z@B{P)s#u~|Y60-e2_ti$cngNNlhP~a^78812BEZfMo9#Q zD-SR_n6ntP&_ooG#A22*f?!kzCG{U8LtS-ZX5rC52nFph8D0how>{Kh6ui`rc*CKf zw3ntxx2~5mWTXZgY1E6fkOY9iB9MVRNRQ1gcPpXZ6CpJvurzVX=%ryP1-fckv)ird z{^@Bq+2-vxE@ zLHWKj_x%*8olz~c#eRYD>h3u3Qc<=`}s9*v#!eliGVr_&p1fw`V88S^LOrzDx z&Zvu#RCYp>2)C}1i=IM6(_H8QDlDJS10u5B&NN1s0+c}8D6AG62S)2w0ajRDEw-F9 zT{v$7K8D;T^oHbRsC4NRuhHqS|IR&*QGnRHQ8BYNkYa|(J+S2wHeRVI!=RX%T$|Za zaoKJ8YvMErnQ|$C^KE=|u^^l8kd;wh#~bWD)uF)@*6ZZH59Y!^yG)?`nc}qu@{^r> z;2_q7!_2(Q*Ch($SS!%H@UfuWpEkWaCFk_DO>89AZl1$y$MTJ0n(g#iNa8SPTivBV zH24F@n?x&qI)Q0T+A>B)v;(^y=~sHbA^1eNV8K0EL3$g=T{6I*0GBOR=|(_Pp-J88 z3P-%*+y#S+dfCnMRa05Gbws@X=|lCL`R9 zQt*+@Nqd|zM%=rHZOBSZoD)tCYw;|rz53lE1Rc%^JHldebwq+a631O4Uabc%th)q&w^RefcCaUAaAWG0zy}YW%j@NM?4}Y>QQIC(3!?~7nFZP2&W1(}BU-HNF*G7R?^@Tynje#h!%M;- zHtgK=`EMbJtUh?}DYgZbND*-8nw|Tn)S{1c45N96=Ww6FI?NY|35sMnxYrNNgIMhT zW|nnZ7%b>V`QTRSz*d~yJ@?i1DV$>{B*GvI%)s;^DtB;3s-OG^dupbD*yL@GaGo<~w^nKVv~(8~yi--b8e5 zx`b-CkrSryV3mU`0jcny+@3>@a3VW;xdNgUZ(N8Sr56iIsmm}B@wY&irn>(hL3~`N zIZFnNP4S!UIYPIL50EDQiDvI#UmMhfsgb2o;c!YV6h$wa`Yuo^@^=vZUm&1y*ZRUM zzQ{m$*VJ(@H=OGMcT$X%ylY|n7g7%RgQJ@_FaEX*H@8|Z5r5;>4$s$6y@Ey-C;OPW zw1dT`$GwmcR03Iw2(MeFv;`Dh53 z2Yu}dqT@L?a+D;lGBAy(WUvmAW%N=lGrfzPQvLzDc|FM5Wp9$u?GL0)Nzm;J;KqaJ z?B6{{TIA~xE=3ICt^QlOUB{CaM44!LNj%$!#}U)~?*IMQnhCFKt#j?;Y~`@F%gkUl z{jq4Fimmq1QZ<)zy6(qvQIN`X6Na1$mcg)=k~aJsFxOwZ1cEvh0eOaEj5apLLt71` z3`f#s1&m(~FPM=0-QwXD+9GH~@v*&ZThIKS+0K$;T$w<8&#hmC^J(227N&WD0ye#| ziN`MaEaR}F{|gJk!#uUKo*@1hClqh z_&d-}s||+x>u_8sL~TjNe<@Vh6Oh%2 zTrY&Q|NOj~D%Y-r_zyoi!S?lYdJH$$*Gi)CK^yFfVARoVg04tFh&KC8OXCU#mZmQP z)H2##%O-NIjg8YlMZ?8WF2=7vH3DrpF>N=?=2+ZCHCqZyS*I8Gv-o9HM&fBeRF78i zsYdlIY=ARdN^i!)=tPqK7=3abdSlBwmr^&zUD$5p4DVf<`&(#(xz`;n$9-X&T2@1r zj{uQZvJ8CAfzRAJQv$UCA^de;#{gc)lP@l(3t}$n?2bOVjo)qCewR}1IlA3}<@)HE=*AW?trdlepzK0sG8HOxMFSCHRFTaBBk>iK;;{k(0dlTJ~3 zwJbIVgH~$uSD0Kke_gi4wknAF3AD>VMbSo3Hv8WJ9V`@m-Z6HB%BTb?GQTHp9Jn%I zq_>$_fCS+1!95~E?L9MT?kqaWX6Po5(Zp=5yf{&O9-Ppg0rS^nbDg@ zmL=^GRE2v!Z8u|eIf9`eon~9*m3USF3E9wvwzpqNS65&%3=Bj9xx9y7tjk7`^**;d zyp>a|uxa|7?9*f*%l1Ycpt9sQ`4a@l($b6uAK`mnxy6Cp^s1C$4G2TE+ zNyrlVmO2^D?kY^<2(JS|%%Qn!pu=o%84h)M^usKx4>r3KxAl?hIvS)g)*GUOq)n+n z&3pzZ0Dz?^xM!D}-h*trkrtwL?7x(0r>4>!0FAa0#~TL`yuu-f&kH3dtDY%OiV7?c zKh3bH1R1)7E`^k@Rc_Y^D`#t%Dm&M%PzMY0 zw>?$;N^`|7VY`5zbrYVbFCV!D(ONtvuMp3aNxa<@D72}QqF>>uomQ|RBmWPI*+*fy z(8@wex&EM8q!t0X0dy34k^=8R{n$d1-_V9vI^ViFt*z1~#7@c7s3;T+6_QW+s#84a^N7j;cM!D_JGQC>;b zMciX1#(N?M*QZ>qAmp0$3s~ot(MXiSJJDo-IJ-O;R+Dme7WRlm0L34%iAFs;DJ?Kt zinfmmT&4U!M!|5QnSTz{jlh!M^1&u zYzFIy4GRMY^?q_fOup?Lfq+5O@{OcCqt$CSe<-ARj@pU`~xWfgW9G+nu={2OoN8NcV{3ROX~CKJ7x0#Sn(e&Tb#Fa~Tbe z^M?u)iO;hZqfS8NP4$WcZ}*2ggGNyVS;q{S#HM~Afkjs@jK*fFJP7P+!K*f#u23cy zb8YHw)5hxWyyd$VGViQ0(mETm=Uld*CavLuY zJ?sv(KG}Qkh&;)XsKS)-P-`!)#LQ3k17|THJfhUj4KT?D(`M z(IJXW0){DQZ>@urTFttxL(NbzzyF))LSs|Y`~GvG^`im+VEy|-v$XqBTwH8_YUQRT z|BK+Vr6KL~(`MEEpl+}Uonq*5KE}*P(lSpaOC!IJGY<_$l*VM6XaF!8`Q;-3Kmv%6 zm?9<1y%cTAgVwRb7ogf_5yPaV8Xu9!hLhqVogu*Ky(M5mrcEFZ6Iu4VwPOcwAyrZE zNK%AD>7eY?y}YZ6DjyMNmANk07>Drfj~e&y#c5VsELyf-%gB$hJjh*%c^x_Kc`@4= znoM}4uzJXV+6@mT@fBJLmt)7z9b_``n2(!NkFj%T>8qrX0ceyf-B;HX)x*MU2E-;h zjeU^2fu&mr=Aj78qioS>GyZ$nnR3J}c9gdEi+%oL~My~tw z<=0Ic$g=J`xS3{CNXSXW;I+ax1t(yU0QfG|%amTQsnr|CkXuJOk5cQJMZ!-o_sTwFK(blIxXzY2W>p zBcu#4lP4$wz?_x^lYChQQo^6CJ~L*H*EwcV+7vRnziZu9kF|l#_$t?)Lb0^8g zwv!M32-C0D4DKB%3;&Y0c2*0yw4-v>s$A&ZsPffA!%gEc#FsPHcS%mF9P31^dP>6& z-LrW>8f{-o&XE~*l_hb-BdagHtf%B~ox9`UVLiHB{Qib{RSbu$HtIB*~4icv0a^quEp+LR< z-SSPY)_nL84Yyc6txWrY_pn+2sejJFTkYGn^*Hlw?=^Mv?{youJwPi02mqiA0ssK_ z-{Hi{(A7}i(9zM*<6oNF|FCe!sL9%|3m|m8t5Mv}54N}D}oocILJ<2bs^ z@C7k}qon7oLOoLuF6GKk=8Q8Bztyl53r%ZgqW)p3Na)F<@Qj~_`5yQ8ofEe+M zM+M?U0i{>z7OWPYd0+!-GK-VyJ7Z@IT9&uzp%pZC*N?FIHNuj~FPf)#nK>Yp(>)C8 zm2QU|jJrw}*DOF2sOgRUEV@1jZ>K@z_6ZxPbtKMd7Z94?pn^n;fgbym&!Ro?o9*y% zz9e(JP7`?+%(Yn&%Vl3QR|@%N?ql#k8P+*7U?~f^ajvdcU#!?s-wlf(iSIs(MOfp= zKNp|s;jtA2k%ov&0nDZF7U*^ukjya23+0R+4VvGolb`$ox}X)7#v`{DcPv)Q!yDih zoDotsr2`s`=FAxFdYD#%=l(n!tNH2uheJSKezI^!H|2F{fIfyZ3Zs zUnf&*uPHO+G+`xKQvVDT4m#?iQ>$DnoYgUM4g;P8+5X*cI6|$nqCRNx7RrB?kl>n1 z7t^$6yfQ@kdgb-{V1o;nJ4i*ZKlqxos#bgkxbj5!G48B0x9yy>&0ZnPY|^CHvi_;QGi~{w(+V|H(^pDh{NG zWFt}5e|pnyIsSjL@_)>4V?!Goqo3}ue@*dxLxq%R6C;0(6j_yrqgib;VG)UK zl<^o3KOo?braYjC@lNv`g!GoffJ>z7j|v8(B##Z`P~g_Z`gI=x{zbHt!3+5d#3Vfs zYQg`Vv5-y@iNAeZkbW^K=+ZoqjtQeVs#btPKIr!Z_!`jwou^xvNr-oeRZueylDw88>KC$LZaht+*?@Kg*M?)`+ZmmxpA9%sGSo!FeHyAM%xQg_L0lN86eZLwQs0+MSY zH&K651|1Qd>xD8W2ZuV~71thi3_iqOfN&%@Jg~nbhByGQTGl~y+sxwC0`9LKZv~^GjkJeWkQ7{ zh_?G3XmZ8@30ho@sJ!3u)78FsQL}X*#I)Mjp141hI&(+jgE>qwgPd+3T;*b)_YQSm z&&CWeuBVnLNw`{qS(59UBS>*Z2hc8kHMp72Wa{5u3H*T_`9Gu!5cBC z)?jgLqCMa;EaxWlN&)uG46f^0w5Sy^E&=(E2m}daSGY`ea zsIki91#-OgWjZ4_1s8Pfb54pS#uS0bW5wn9!X&X262HQMo=gSo@mVczPfPc1CBI7x zb(^aDwq*(56l*B=Ig7WuJZ+p4r}+qP}n z&aCv4wry0}wr$(isonjh-_v_vSO15I6>CO}Id0N>8}KhsY?ln07Ro>Gk}5YAR6hY+ zvdN-9ChP@oF_;X=W#zB3*pbQQ?Od{zfS57TG?L6_x~Tkvd6i1Gml#^IAEuYlKfNij zBOyo}zyhC)7^m<&IgKjP{Q<=pWFbR7hSiB#y%^k_az7 zmwC7ChL^!4=#&%Ve9s(VRp*Wdy<+N>P7N$QFSBE(c`>)1Xb9g|7pB_8Fqe>&=%>Qz zVnee~sUXTC0O@Z6^rb<(mSey!K6i1A;2(!>DD`G$na2`kN)Pe_<@1u9K{Zk+3u~)p zJ*arZIs0Vr9Po7VJ-95^4=3GuPC!7ZK$RSM-vtI2nP-fv;=DuF+ti5W^pp?y#I1aR zcaw0j(e>+_@T2)pEkC5cChx2rlqj-9w#&6wuXL7{TBX}eKcJP{O_*Ipni((`_o@|Z zhUuuB@=747x1SN5_es*0?Re@o{u+<4OS(@4r6FWNl#Dne_D(uEsN+x|WlcdhY)~ls zz}T_Cs$jx3ieyoNX-Q+%=Xmj;YJH-1BgS14Y?A5yCLx2^CCu_`iGQDZvUqj07;bSg zxEVFYBfWZxSy*)c@dXv~xh@(6c@6u;2`0gXEAwTQiq1sEMbY-+2+8aPttmTmxOCG((3r7EyZaB$UJ~5zSeYXeojdCtskqKh*AzIcNB$o5;MO=Q?<_icVi&vv-zvOC%(^ZC) z-mlMJi+jQ}5mJQtYo}d(s&QPj=cM;cp@gpg@)6mr6aX2-6dXEV@KM3-LTFM-!%!Ou zqcpOgEXi=+JHStLp?88xhEk$dN8s6;%1ONAxusGLRbSiLRnI+wP5miC8%0Ctwbqf< zzg(!@wN8DkUH@EUanuA$Lb@w08>+8z@*7EgiEz(&M3LXJ|h2wwGFH&FU7i_Yt> z*ASUu)T-%tDsHc?evDJ(DGH2f`-6oxN4`loJt?t?&+3k zSX1hN!n7oGV}-c}7MLH14`@jlZcO+53{^|0E9L~^jN?(sR!fM)5l=S!kFaE5rK{Sg z;q<$mq7WyvY4&)a<%~<~R40~GOWY8av3;@pvvE$>%vo9Hr&=+yF`dfMk~;L3brWA{ z-}n9{4BU^d_zDh2GW44_Zfgn&s)I8@*aG z-m}e@)<(CtFo*G;uwY${@rJ&>ipP{CojOO&B~IJ<=Luj;%}RyI}cs?4|V$@(aT{y5le=sRlo ziL;9b1*eG+&6y-T#7-zLyol{yx%-( zJfZME;QyHkSFts;SN~lxF(U&3;rut5(8TgTDa`*N5ngF%%U`k~`OVbU^-CAjsBHmt z;YLD633O#UQ)MrN!w^l7CJpWmmS0^q{B(DdN@tk_7KC=QAVy%#_&5{bIQ3D7IBr*{ z0S!l}!DGNQv&woxq$z?2F|-LBsyQuF^~w=LITfzi$X}3XYRp zFo4epTKwg*?}((+egG?0)LGEW6geHgcpfjhf=)OI@BL)^7c-GMPUdsL#GMLFk!2_a z2ZEg>m_QKVV|53mn~oWth)~^d&+Q;i&gXfVM8W5Ip0soIX#zBbhNk3OjOJ6b*22vx z(zc7b*!DsWnN`&1Esn^{$UiH(b9b2z3s8Zm5CCKZxTR)${Q!(C_elg?Dv;`%`NQ{gY7j9?P+&ZL0ae-pfu2HaxV(Fr%vM5@#6|qYcdm>_X zJcY5~GrM5~7CJpW8*}b+OJ{t)#7XD{&lS^HTM(*VV&YGTH<84)vq|?3A*;ThlVPlR z|7Q@b+;yRWmJJjfi@dFyGG^8LWw596P(eBJDqPZS*D&2NUSkF;46-Ykp_*}0i!GN> z1%#`W?{s2PPfhW(iYU80#Z*E#Kd}mUj7=8!Q4Q`uG5I%3@P^&F>cS4hmwO0bqIWPQ zP&u-KKEmO6;O3##0*=W===!nj`Z3pLX0^l4^6G+MYrOLp*c+nO>5rJljhMur`@jhs zM%(I}<)!&xYm>m#VbqV)kxBTRHq6>8V<>O^kwuHC2HlEPtaP$YSZ@kXW ztNYqagHNj#FW4`eJyn9Kd9TKEE%P{it>=!?dhU%cdr6Onuu*yN zGASx1jO{y|u$3P-o6S4#ztr5zGV=`s-Jxdf6D?0B@m>^Gw2l&xaAnzK@HR097s8)e zNo;!`J}NkV4g~5moR%76)Fhf+5>%1Fb7E{0EqT@*@zlw4JdnW6Lcdr1l@`Z$1P;FS zT;X#b)XL^vU&l^!FfT~E5T;c3ROAK{v>z7gvE#&ly4+;YQHttd_m07@NpFs(&s=hF z61S5;5Y`J8D7P^A{eN`=4Sk~=1}IFZDL#XA`v0$@)_>Kqpn(2MGUk^Ctcphn1VmQ> z1Vr%Pd@RfjT}<8oe^*HVgLCYX*V1`ItYPPgdU&=H=2QU+SgNJa%d0s`X6^o5OVO>t zv`vAKv&bMAC|H>3A@14zl3?RH9hf0F8p(BV2FHGxJTNEcI;Z>ibdP+^t6OxqZ}7>y zVR+SWqH6O-^DrX~jimQZH4bYkLL<$KN^el?%#!Y!v#y`h4y!ku@!> zFKhLgB{C^~oz3winlMIkzhf4-nU}pJ|3FlCKc=8p+Q`Z|0)E>G*8?AK#BNvIasm2_ zH~&nVzeL}72CtZYD5>p$n;F}w3>vS)c|=-J@1YiF|H3BAFN6Dz)t*tlb5%E_FHP(h z6p^jxbZ*hD;9!2C!Ds03@7eEsUl*>wmx+68eLGisM?(fbMbrlKDT+*5Wq*yh^ksDA zZ_)IF(Ou#O1 zz`jFb`3rUT#R^SjvKd@bUQO#m1fM=$X`%+`#`Lsw_rbrbCf2`6@}aLJ@MLs%?l$=? z%I5i9N*)<}$pZb|>~Osyyj(A6R4caoGTA>xQ++w}AAW%gnX$B#(mj!X9%O-JqsaXU z;J?*5A0d3hO?(=-bG9)roXOJ}ExUSGynHpbf>Z{dYD3`)$PUXZfQ&uKuM21C%+LQk zei;n48`eV#gwsbvc-bM)Uq3O&NX5mNXUJEETpY)TC2e8GUl6r9Fo1wYN|?uczHIdSXT_27gU0wwmxjaEKBB@ed>zKZeyKL|)qj$|9= z`R)DApIg;b*eGo^Xy&_3e=QcD!>!sm#^lDq!5Qcq3hx6?0$U?lmh0Ks`HTB++$Y8I z@Ed>Y&e{n^@Ic|t<-^xR>n(8voWh*7@qQSegr}YvC&30g8jGM&j6n_ zPslO>7OZ9FOyI@E5FbbCIKdQy-}txio?ZT^KQ>tn`&%S*ju`(G-pUG>t`OT8jSIO! zecUqI`X-4x+B~I1|J}kZ!2P)_Ye)|2I8jpZk=fm@><4O#kXk!#YlRZ$A1lYshsMIk zo*F?{RO6H+plq`j{Fq~?#SH})KzdK9dwOXPQMbZem91x{K zBL77N4XsoQ?E~=7LKt9t=5S6Rkc@cpJr+4~Lqo`ek-<8v>m5#zq0G6A%{f416>WnRtVVhq&?Z#4KE|+<& zJ;f%-Nc<6;jilymOj7sp{1_n$$bdA`)o7ZRZEkoUs-mEJj!f~(v83>dup{xTFbl1#bu(7EREGm@yrHksV|lxcpW9QfRR|qm7 zX1>E5(86hDc`uxfyfc47C=C!G^V^bZCyKscvLPbb_Y#kFaSqdlwM6Xj*dZDm6Nd*X z7q2a(c8?Yvd7(y?cX1`CUlmOnD7y!T%;SOZ>v)ki^Xc4tg8$ zPBj!N=Hu49hS-{3KW!pzAd?#(vk5IntG>i`z93N88eq`eRYJwqLLC@%z&HKNfQJD~ zcIdB-tgdvF&J(aSFo$F~*T@#K$K^>=Vfo9GcfOAt=MXuWOP^{Mv?Vct;^d9mY-2ap zGQ@6w{mC4{lVHwrKfw}->7ds&x~jG>SlA5t6D7r->1c?c#ffn@E83N`zoK1b`+$y zcv$+^wJuc1tioWDHe{)9|AXO=7=E;h{A8V-M3?AYKs5u}$HDuMjnA$3zDN6WikF@P znto^xL_VZbXZt)AdYKA_x{qdP(ZboMWEYi@gt&!7D2Ur)1aTPTzXy0MZqfD0u_vb-6_%SxbL_h2xop=NiK>^Jw(#E=elbc=E7$}@1!zYLhX*h?{)C+R zTz+KRP=Q6VkuTz3`v;J5{_FCK!Aj{sMOiZebD3W8)bD;$T+62|B|Urk#1T1uo#|$j zpisQIa>>)3KrYeg1`hC7lLOQ2nSYqlY|+O{^r*dFD0^wUtzkJH1xPYce5i8l`<+K{ zQzP|4QBIxjy>xjZ6<7D-Yq1FvKf=w=b6a%TMCR*8UqNhUVP++e5a2}?s8egpcJ)TO z!Y@i_Z`B!nw+(|eg`i`=xxj-=jutv2$E=B16&%GEOOF3KOmB>@maUnjC9?L=ml}AZ z$v2%`Crw>p(7Gkq2sdCBV|z1X7FmYTXHsEM##BuEnWf3>b4>)=d}8Lz;S+4HY{-yg zC?$#_V@yfx>Dhl=7s>P|o?G_UXMwV8g8GDF{)QRxFYQezph`4GI}2X##(491B%LAyHz27LzY z57X@jvlNu@m)}&#$tu4fK`SsTCvZoCIjTI}KKYvHBF_JdFvhVr^5Z0=Ri^7!Kr$8U zv{cF6OkNs01fOYU*4TO`r@-}>LZ%C!e5q2G1;hA<#dK7G3;>uCIifArs zb5&?HmM@KxdY=^@y?(D|OgdC7ICfXO|Oh!1*(sF($uojeg7}3X-G948j7Zrp`)z58#jXz{>rsSlI;q z>t|wCR|Flhlcei!hIV7w7rYYfWd{5YNyd}0<|^NyhL|^*D(&VuFNvqOG95hPDw1#{ z#pRS9jzAK@*s`qhb|a=^ra7Tnt(0{d1dA0+IHjeyXs`=oh`af5G{w@ueFmKn*Kkk} zH=KgvW*pH~Q116n{y!rcn;!N}w%!$Ho1&MNU~V+Y)7#y9K@S&v&SpFUQS#=-KRx?+ zVU>e-cnde}dX(x&j!Fs8Rfh;lmo=Vg$Dd-ROVt)x#Lsp`olxzf7fmo#DWZEhH?NHO zky?)w$Rlbl1Au3w^aVZ;lsZp|{pt0g==UFeG&ojS3DrW9_$&9PYuM}aq%4HacgB>Dnn-Qa^# zXn5DjU&XGw<&@5FtIe2kw{-g2*{hp^{_g_;JYyBO_U93>tWI{+Rro> ztZs;fq`}d4U})39rO|vY%@i?Tqk^+&!6aKM!tn*q~_%mD@0h_DNy2Ck`rPiaJy@V)wzi^ z-P=OcEiTF5$hHQh>_yhd?ra4-j>JWV_svoW)|BPY^%jZsE_BF|`4EC04kIL;2aBNN z_UvKmQvl%#6#d3>u@}Z528%cftl_#1)wgvs7fv2?n$u$_YeyMjp4g~Kk>a|!RK?bv zMcX;d#Fw68qLyzJ^H?^Cd#u#RgX85*U{Etvduo<+mG49rM6Rms{+_WtJR_hUvD;rf zBiTzW%fE}z+$X?!4FK_T!P(|t&nZE%p#E+u&kirzC;$X7gKGVS_mLJ5+w=|kR{bY^ zn;xqjGivKc=MOQP6P4@ge6rvzwK&6<4Yy{HCKz___(f9pWl8|={LZdk^P#&!S*PAo^i0O+?F{mOW%$9oBU18t+d z%}pmIapJc9!;2_BV*NPbi=(L6+n42c6K8)n0y3sQs8d$N{Y<;tp?tlc$6AwcSqevy zuPc7u3jsluckc8uvEZDQ?FGW4%@|nZb+icQ_m-*Cd6wGCG#q@b5X&YM!-XZreLHjy z`fLt^jz;g-Jj2IVcIt+j{YHb{qu!G$3>Lku#Uk1y;BhRJ7sf^Xyi$5HdtS3%@pj4g zbayx#BA^i)UgCFaF{^|l&y}|{8k=+a@IpKMdV`E(=Th{aZGC*T%PRXs<{fS1&O1`0 zbW;7lB)-rEEC`>AL`8)fVV?Il{)`1t*RnC;t9^6{K3;B|0h1b+G|@OxjY87~8G7}j zH%wDNgJ>o9Y!BJ#Gev(sOf!Nh$nMS6-Y<84#&a8_v1I+|Ko=3k_Ja=skHpNxL+hyw zs!CKq?#%?|O&_qfrHzjRcT?c(SOX!6PNUxHU!b)d4qrXNJD(i6gUxf9VR0I~`McRc zCzKI(pi>7-CL?}xRWFOHvUySs#wXDFdHw==z=;a@#kerBYoDe?#7DOshUZRJvt_$+ zHZ0@5<-^WU=eJ&}XQI;^0{<|gpTfwyiwLXj7(MtSdcIT{W7A3P+JDImex8iXS7R6+KM*Y5(HN$xK&F~2 zkveULMj5e>k&;we@r0z&ut#g_GxSlA3u!=`J7u;aURBgB(q{2nnMMbT>htA%v<|~; zwpls4ndEphFl>!>=v|zrH`;tmZ(qN*eQl)r!h{q%r zH%6ISo%4J>W6CyRJ#2Y7YPl9$`vpgv`HiJ%+s^J!kgY}6!YWsidj)|zz_nc0nln|4 zZNftu*{R&JF`)Il+_XUzj;yPzcw2#O3~Im`H-dZ3bqebb07kq8giaIe+&A?x^!H%hK*e>-5b2k*>pTP2b6iFFni0im5%vCaX)mc63tO(8%*KP^MEjuOxc z&99C@#U5P{pr6OB(LuX;`&T4NFUM76G`|Q3^-JPuB4kwyqY&OZgiM(ZKdJM~MY9aO zAb<0Hzea|5>~j?OT~cn4XL7Kdre#DCI_W}$5_4395`^`h#1)9G)qv^AMkAY>qfY!g zHl}ZD1l_iX(7{QdGtF*J=sKc6lNg-5WqsYM3A97Sd}LUB)UZ?ue$A|uIm8BiJ^T?q zU~@tDqC;;v56EvKA#+(f%fI2QG={z&kA{$9LaQ_#hF_mqQ#p~w2W)_7(sQ^MQYV-G zdrg^k%^}yQca9)W*-6L)(C?LgzDV4|>SV~@w$Mh^BH@d8GUS-xLhgMXzUq_rNGklM z!+=d?xjI{lchxc4IDWC*usbV}^H8#tut%WsG;%4G!9f3cTmMt%>4&fGP7%7$8M+@A zhSzu51U0VHJxG^`kPt)pcRn2@_BKkf9z!0$l?mc01CFFDZ0H_3seJ=QHQ$-1?8vce zn3c37MSH%~QlR=2x*bQJVCWfwxoaP!sy7t zZ!)^8|4&{&=MHMGO&TMWX^C#Xqa>7Uo538jDY135A&)b-C8q$2CHfA{9kzL=6ymJF z7R8+d(@Q;&v7Z$Zq()Zw2FqozCROKup~()tE^xR_+S%W8?-h8xom-KJeE) zY~recJ3^F7o{B}UU&2|0t#XP6E$1FFCnBFjWjF+VxCU>My>qnonO>$+7}=)DQ?bnc zblyIpB)<%$omxez5Rm{OgHkkK$UfHS1CT$f=$B_NaXNt<-`~d;i7a{hJI4gBP9fYu zn;iob+wPE=M0(_&3BvUnWw{9kN)r8E&BYzodTQd`@ee}ec)#oVqD!g3y$5bKC(r{V zrG&5tW-{gYen~Gj*)BgzGKfn>5pa@kZ0pLqMRn0)&{4+*lhFKRp#mC6;F?jmO1|)> z<1nIO!S%=h8m5|<_g;A^Zx>T`XAifMJE#_;e8f3Th=5cW6W`o0fa&zq@nnY@Z6_X9 zzcLnu^e4H6)!Nxc-`!oh>Se2ME@$iT82?ppj^f~KXx8xpZ2(}{r*zpm3C zV>giKP@z7)#@&4ivkiJZ^BJDMshWhI3~9N ztI};GS-j8kw8zNd$hM?pay}Y?ML$RqGDqT6WH8n_FV;DkL)z#{EP%L=p^>59cw+eO zs-SwU-bXB%bL5a_Jnl&T?h?aj>XeO4l_e@+R|wrRU_w0-#qwr$s4~oJ>XN&Y*NYqL ze-EszugYuQ?wNjQ7$TZuq#+!48|C)>gyn8c)n!VZy!0a`{7B@V^}NE5KbAM=ej6=; zuOHtqg1%{gJj0UMu*?WSm*Kab=dTM_VQfP4m!+zP8*n9Kf5Cn1xc6}TMJ$FN#Wi`S z=@&>G-g$I#&6{2ri(Lf=IXqF6GIjf;giTGm_|Cqu=1_1>CDk^jW}04! zEfH0j$iU8F&yef;vt|jMd`bLZa6RBFy5OBmgcGVm4LXOv+GZN9dLG;kPAn0Iu)Hb0 zY$0?B(Z$P;;QU}|eKth0Kf!mSA+Rb zY?gEi6_T9Lvbu+|X`gwrSvnb=NtW|Hl{vO zgd8g)fv`PK>xcv?f&%4T0pBixup^gV+~*g34tl-- zc7pWw@)EE}Yk%A3#Wj#HT^B-~URNp6!%RirIrDp4rK>BUD5$fXCrFQtRL zA2E&4g*z#1i4$>Ix=N1H8Vc*A)EQ2~|BPY&(~e;uw8T|>%?4c^LiP|AZ@v1IJ$B4` zKrS@bgwQJ_T6a`s!l?#AeaA6%=Xz#DG`6iPH!j@GKqkRzGa-QNK>)n-*8ZypyD5I# zaY#LA6fKOGs)-V#W=8ygwyU1U-dp&F%NsO|!_qzn%znEQx?8uF1UHVBKul1{PY=w$ zI$qhG?xLKqyYy~ME#xEWweohW9Bdj9YS*`M>;gxG_B^?9UyJH&z8UP`HNg#qt&z%h zRj2p7S}SlWN~>c@XYpB+n!C~Gi$rQ%-6U4bG-bo)SWM8{yO73~!|YMk>H(OfJa>!O zf$$yScU}2t(YXrc)CD&Ni&)`qFyfBd^Mlt%-PuEvpzvP}H8-SWYlBg#~j3gOGrpfCQ-LIT^S$f+HJREKU_U-SwcYxZg+0X3{RkDC9ex& z`Ej8Aw5_?&54ORoN3MCf&UcegDZZPAB}i;Cv}1)ApLgX!-9ENzYwQ~9pZ3g__d(K< zyYrEenL^IgXHQPJNWTrX%cWQzD7##toP(I^ksBo}UU>Al z)$9RBa4exyoLIT~GVT5N^}`(DfX{7_v5NjL9Sjn+(XPCrrgHUBj@YCRjV02H zkFTqHL;tsA5kE;j4B*9@N3PQ#GwA&BvvPJ6fFV?xhL34*>Juho@za^cRAl3tHR)>G%x~=LfJamcH3E6 zYu)jq~dv~iaMBUAN*gWI_hA=~`!ymSPi zoed?l>eSz^7LqU8iR2TM;a({skBWONq-^EDjA#O?1IxJv%w5RWIZl`VO4piHm7b|r z!%=RurWvfepH|Jk+bxMLEbHoN%@dE}Tw&FoC!6flQGRrYVX?%NlG@}8iPJ~oc;Ion zKChU9HdgX5N;QMoK-avl-cH?I>o)3^PuIS(&8Ryr*mlKRo7^u_NnD~Wi%WE@*xI*- z*60#nV4pV3=LeD1S~3>;aJMhwZ@}) zsc;aWR>_T@vA{+NPk! zlin&g>VN9nMSVeNn#Qz9z8F1JV-}SUJ{}csw>HS_%YFZ0t`FZH7BQ?YvTb|q8z7~u zs{GOp>mZ14YjCBQdp0SFrfiv^AWN6M`?RXB82^bVvAw7V-TzY$+QI!_b<_Wa^;?+Q zIQ$3F|35HFODTPJ8w_w^H=Z!WU6i5*d7^!NjD$^7@i) zO5Cz7xGlp=vh(s$pE6QX?a^lEv2i? zJp%|oTR&}F`qlGG#xoona8rDbeVCwjnqbc1q2K&*z2h)BuVl%588i{Yx*QRTnsk-9 z{6E;pdQR(J^G|LUx(ampBlYb27d}M+*q<}l<3%7V=DLHrgBh;#TEbS9wJEHOjjj^m z8F_k`xEAO*!P~ysdqIBaL_F1S^>SQ!MX9Xt5WdtjX4jpPdTxbKLj{c>dq$%CadG2h zP7YeKt6lf%-*#aEr-cuAi6ue;8nIcqql0S%U95Us0)joa;$TxOAV=`+g@rk56y+KE}Kw z2p^?BT_YFo0VvW}*-bnQ3TP<>x&4VEz4T){MBwI|cUpGYpWP<89vyrJO&>uZhM8p@PZ zjD9{6l1-FR%l}%EuSggqyx`y8-s(ja9C66qz=Ka31WE*L_v6CrBPi^C^gMd?9As5P zW#?JTtic$qUV0%w()?b_W6(LKDMbpSTrUb?N~W00q9KeSfj}Ku1KPWY8RQJEyFGnE z$cYggrZ!PNyY{)78Y3T)v6il{y`x32=a}Q}tnB2wZFd)40s>&9rtQN!K%|4n{Wdg_ z+lTRJDdEIyMV@2CNsLr-G@9;G|waiE1(on{DK)%pH1_#y2iN<~iC zr7f)mEM=1g?|^C~VXU%|(sPteUDR<(+EVTu?8q!pgf;oo(_JFr(xw45mPf^hc)Ytv z)Q1r=hZiWs?jxQ+o#Lc!U!+fJW0qWIp{F|^f{FFsl~x?AFTLr4Rd3PX+mhc@r6c%b z6#aTPn4{A%_3MPGB7-RLD{oHFiyeg=uC?yO0U{`(>tF|t`Yu{#e=vtE7m7O+FuJ1d z^inQ%5C=a-QZV0F(T|YN91{pN|4z@&IfzLIj=c(9oL_4dBYR>HFK%KYtM`9^lGcz_ z?11dHH9|&M<0&}g_kywnLu8wM{mhgBQ&Ns^1TQ)&Byum}!XjYr?t0v6HeN8Vz2N_# zenp<>&#_KA<^UU2QVM?}lC*&&6l2;YZ<;(KuzPf@i0SN9PLXq!(brH5Sm%JfP{5GJ z+-Ks-3q9u*CA0{z<`85mUcgkO&7c8H7f9^r-~IKTRYzby+M8-G+`Pf!7LoAr&<03Y zBgB2Vr_3!(jLS0G$=Ub2CWMl*7^Y zHB@J^HOe6?oT;E*FO&yKWK%`_$>G1Neza;)p5unqdMzJDfI4@4M&Dl;S_3gsj)*)j zEWDUgqts2GiOUV7*8(V3RW<2x8t=G6jC%ICAKT8^oLfbnETFGOV993x z!jT#Law0!{Dtc*HF6thxqc-x|Uv!tVg)Clw;iH$pqnC%) z4Nhuhkk3Oq;D+#9_-CI3(ldtBlC2PTdh~#eyZTAe@@vz%oFeW<{ws+bX ze%4d9ekGoEybCI{cqU1nY3koL3f1S0A1rOySJ~pszPxt^6N7(U)|M+2SI$un+`^bz zCN$}!!K=VYTRN$#s8R)dor%$9!s)YPyIC!>2`3UF&db{W>EjeB{RUEp?gE4LLa)1%vR(fYnD-TFu^h@*#&4WfP0 zlIPd*7p8mWhOTT?cl$%Rp_&RL%U(O)K6>HIusbeB`Ks3|h{2V5+Sh1*P6)HJgK(KD z+C95j4indwH2^ZRVVc3ptMS?l?$#~YM*cc-6CP@nwst{3yP*_QS2rIfz#BK-n3c~x z{utOUe?M49hGtvALajWaw&@i)tvm6z_F{I%?8jF2c_~^z-lBQN3&7!`c8G3@?2Lka z#7|c!tdu2(f#@MDPNQjQjC@Mr%6z;-*+(Zh_NJUId1sizRg;?gX6FPCE-3q4Uepi1 z!sv2LXnJ(l_v>tI)*C#PKXYVoxSvYVSM+gmmdO{qR$9Eeq{z5mC|*Z=Zx$n2vTUHt3wVnhG;?6$4p|I>fbklK#@!aohijJ{#NG?l7iX1}y3 z+3c`@goJlMsa3<~aXY!J(SfrX#QM+lnxuBe6`QQYZYtNw$K96uiSq$REFB9JC}$FU zM-aF-x5=6XPgKE60jqWSKfPp=rZD)EFKW`a)I^43x0P&VptIVdl-0m`iN-jthbrc5 zK|c`+Y_h6#fIGIL%Pb2-JlHg#-Cwz+_%F(5ii6G2*Yp8KqVKtg@UnePC_1YG*kepJ zGn|eC_4=D6u^23Tao=8PnJTyrWh^u%xFxyQn-Dl)@SV>x7Px~)FH^KmxzUBR-I zTcq?OkuCK-AFKg~?~-+x-?1@o+qNwN*T)usIuUQ>$qsoT~m|7{Qf_1 z=KtD13)u3Mw*U6e6YPJ#f1FMKpZ*88)OMmX1^#JlYB4XPh2j^eeL0|-rT{5IXMzi* zg)~VVlBlk&9oF-oFL}60mIydu&v;GF1fvJzl(1(`E z*y7Is`~CBL1o;`%Fd>&ED`?S;ITS*ulTMa2)o-HtOhzQ%a7;4^8}G_m9?76<&$~@) zcC~!vaZSy}`zvEaSb~RLsvlvEt3ev;rpE;aoz-gJes()c4>nB%x<8iZpqFtWGvO&5 z)+Z~`WNw!_>n6_Bt?SX|cboPGyLH&G5X|I(mCHog&7c|T(KpxQtwW&AN5vq_5^fzg z`w3bm(YywjJ%rk{3Sk@IdSCbBt@g(jlc$5;3_#c2v^EZFPDzOz_Z=9#>)qR*V;$u- z9aqFv-t{b)*F;h!LJHV;7}RGji=~oTH$EBli?O(Tqz_k#yKbvicwniH+MLSo-eSLD zY;t0eis~kY^5Yl45@jUYNjKdna^rt}V-$YT_BVAgmeZ7X$fdZV}XXJv^7B2 ze&`ypJG7pV>G!%i;Psxs|Dcuy%N>HA%_mRL{-Nai-s>T~3Y<;FQubhDA{jOZYSSgW zO8Mr7H8DB)nFHZ`4I|0f)0n2+`ofS*2H?0|IbJRt|_II_0P!9@GlzwckLY)8|VL*L{qG${a=a3 zZ>5%SF?=+wgiNpl1e}G&wUJeD*`vPE&;ll?B-Vxt_1`Ei+?$mf1q|_)tm1)BqbQ!w zY4>Zm0R~JQ0|cf5v^VhMLl*MV5ej~YzL_c~Y1{3Df7O*F{?vX!qyis1mwzfIYIUr< zOW<<#T8T0$IGBH`$Gx>PQW9g5)AZ{X>Lh3yG;U*V;@g!>6o9E0UF4$Z%ue_pxK0PGIaKAD(;>EsX6nlQ%}S?oDrt{% ztruLUyLJ=n&Ac$*?ua<&Pk94PvQoz>V^XkaS5R;I?itX9b^Qg`INqs`(GklW$+2R0 zyH2O9+Vu+YWMrs&l&in?lniDmxw7t;H_mcCNPalm+H!Q`YHKMoz8sjmfCs|W4{@{% zlx^;YPuQ8goWDP=hNj-aQ)?c`3{2^pUAGNH($Zn{+DTa_EGjk5OP*~8(_duQIlX5- zOH#gb--KhrU?6F9x=k?W<-BF2WR{!RG`$w9W1mUpSzFLsH(OMTo^@LB$z|i@xJ>?9 z_fhvZOd2fo0f$z%8N|{@md2q0WvQZ%V#mV8bJ6)ZxaQY$bLa-|i}@p|r&DxieoIf_>7ACCt`j5(w+k~b>@Cj*oPcKX`8N@i!Dc-Gg_0>> zNxd`L>*=z=7z6NDij+Sw)GS1XkHoj%e0`w~&@ zoVEY1SYfr>9Zz*8?ci|g&T;0XdydoK7**~kGTw9beVvsT)@DBZ^tNvBYXd``Z)nZKi2XoVA%Qrce9Gr)N14-?^ zqaq(?Cnw~;t-L;VBMqq*hP?CB)?7Fgr%nlB{F&?#&mVgpc=AdxVEB27|;9-LmG^xqUv{O@^F|GR{d7$4zvKGI^Zyz{ zVryz^@8qd(V`})Hf{*{w8hjI*}}rwU!pfft#av_FOD=?_wkTlKK?2W&)zA5FIFxG z_m-P`WiYqTHg4}`(zPa@IrAe#)ZHw_T$W)C`bN%1(+~acWID>uaB@&**?e0k3ZSHK ziN}}X7I&Dzlv`=ir`~;`yWWo0*@)4{oR(N{7$N9_mdBHpph@9%ZL=R;Pb)b#mvDxLl2RfJ5*Rw5Ou}1YOE3=T{({; zZz4>apH?yY79{IS-&Qz&xK&&(m+`5l_4#(v)q1>fz#ljDa4?sN-cY*wZPt(Jv1%53 z(~RWGm~(P;A6J66-bcEM0Oe{|AK9bb*cM%=wr#NT6b`pOBk901F2NAIz2~S^uKZ2P z3Q;PxRyV`&PAhUdm_O^dK)#CnE>)=e7$^RvPV3_2<@WjgNdJ8&{4=n--}q+6;K2?D zUeRKEeR&yk$67E+JgMRuT}=@_Vu2sUhD>1X%#5i&sO!i2cM6Ag(Zh-JqkHFUdeGb( z;fLv1t^%r~th!~1j$r1b#yA>*tvkh{qJ7Tzv>dxh!9%Iik2iJ3s1CIDvPU z-wy!k@D#um)4?MyuIR;>w+q%-Sj;z%+FUPvSK+kyD_QB-R&g?ZM}#FTx3M03n>-DC zP}rH_&7XNJ8LnTqWJB^dIVHn6@Pkzy?fxJwb8f(y;w@? z414TEfp=RPW%8V1frZkomPj5fhqHc(56ZZ|(G3hp{B-ig;XS+Gt-s%pM9ozmUi-5{ZWbhmVOBi%@ebSmB5DBWGsEhXLEA>6}PKhW>~ z^xng>pXdB>c;DG;*38~Bv(|d)6D8yX$Tr1`w&%dx=eR7)7Ul+5$6kr_uXQP-S_%?H zPRs(onLB49nq21noW^QeYl#WG8yXrgjO1|~#BD_iep`r>tcWysA1)DeNC|V+IsxaL zJo2r!A>ADNt86HtJ8}U|T;XZ58f>~Qr({_77xjL)LnMM=bOr=dt=d?;hLF7|2$$Cs zB+ehV$25ST=Gc#6(BUkKOuC+@H<>E6!;7n|amIZn1Zuy zqstW0@=NrT^HnoPYM^-DPosi*`FerrcO+p2sPnbq!$G;Hsy%Dy5~>p1ZK;HM*2zXb z{fVjKU5Lmm#Y@JI_1?xT_vk!KQ4C}L2uaXgAV~d*IxF$Zq$u^4HzobNY-0FCg%ONw zZ%eR|xGhesd4ug{yZHLbgtV@lwAh)I1c1Td#i{O4v1y!Mu&dv`G7V(zwG`%aO0aNi z_>wHsB<6Hscd%CBkM?XZPsg-W9(~zHjT+3+1_|D z4m1Qf-x$=LQsR)%6&u-!`@1=sdRP<_W_KDDX03Nnp9)8% zN7STZ*DEC0V~2$fz}c$Crz=z~TvXrDxfN#ie_E~O^H{gL?^ilnTU%L@$W;Jh6t9rw zS2fR|y(VuWk=!m#z)13-Lw?D7aynB zV#$BGG=9&vB2BFGt>%Jo-x(ZsuXseiyF{#5?JMw}iQZXcOzBvR))45lHQL>z>5JaV zfzpgth1o_5yCZWfJq$Rw=ExS?{a^#2G5taIe&lH0THeo(-hct4onE9zLaKg)x(N)rX+>f_{W-lxd zc7%22&ysUVtT$l=Zzm2U?!%oDY6^5vx>6o^yHX`|+k#`hs&P}Yvv@@8eo1_f0ZcPx zT&u_}tv^wX=_&fMtd}V-Qotilg#nv9c%%sAkSw zRBdoTcMjbYl^TUAPUyV0v_1`_M$Q6I1Da$Q8$aTjCzRYQ45hFWr(J6gI+pGvTL=}h zI|!DOLU#%X%PqGJR8EgP84Cd}M8;R@D1)0%SbiQXd*2L;xa-xjDlB^&-RP?trg-so zq`5Ow>}9S>BE+sdQClMG&hjbfnz6zMs`20zUbQ1H*k^sLP&Lpu6ydBe;kr=xq6(1n z0p^XtAK3s8F?&{NzZ9qqbb5;i6rUxuFJ0Bw<{eoaCsSC@tY5(Xt#YektA@zQw6n_s zcCEn6<=P=0$b_p=s2FAJ>RSI@6(4#vD6Hq218wCTy=#h&n#!kMpAZuozGE9lNKw$R zv0f--2bR;?X#FlcIQKQfSZlJ?+v}?WbKM~VYr0*b%EmbRQ7v&|v z$Q$$q`*Z!%qIQhHq-gC7rF6Cl7l8yY$NhY~gQODXWOf~=zis;gZR^P>RFcE6$>jyL z4ZT8b6T6Oo5ke~~Gi9Db82>h412BR|@&$OcRU{@Di+T!lFOB_m`Lq{FA*NU^u)|jh zyH;GcwPfba2ffV(%jT1acxrP~818!!_Osz#ccc&U!K*f)QVW#6#Ek3qeu_A?R>qvT zlsONi3t?**RC^AA;9ST?w$2~ON?J-)#fiydxTD9*?qVW%5MQX<>3d|0Cq4&DIyG)H zp@4TIwk>^+8geb=U%AQVHD9mAgPM6=m_jh8jdpO(;qR1>L1d%P5Nqs{-5DENMnKX-1t*of?aYM0_A~?aI?_FiWK)WfCdf@^mOF!^-g@Ujmq^w=exql_E3prwfgQs$fb?tyaI@%)D>|xQPe5Hv>_Gmx>P>959M6FQLoRpl@Hterv4>7 z`Kr~Rj|TaOTc_vb$_~xjUs(1wST|OB-IU5tc)PW|Obdnz5G~!RwGiaHyVWMpy-b(> zVt-!t0unFhCQ{>`uj61%IA~ET_J~SV&AGOT#DomDd^^~3w!enGW-596^u1N}J+td>oq71_kA$Ppy4_xH?Z zS>PYg2)c7O9T1LE6{qsr%*Dy#UX27RO=-uaZY}O!O8`>{2HY3NcP$@I!gwbEslDcq z8j@w^M8(qF8&?V(fHZ%ep$z&4?h6~^BEiS%i@nc!QDxgSW^68by@FEso;1fc;rixm z(bH<=az~x#3F5a_iWyqUP^GEfShze8EY9a7@%$R#70`}WOrc&_J8`Wb8Ql8spN>=%i zatt&tMxs$j*O9{ZzupK&l-S@EIUw00l|zlh0iENQRuVekAG_PZ$J00~g>vaP=TKQs z7_6~<6Cqn7>i4!SFN#@tyo9|46D?c0Yorc2zRyQ&sl~?BEoP`FsuC2PSX8!-3;m1G z*)UA*0MJa1ExZLzu}Y>;LcaW)sV*EX-j~LXtpr?dFA*e~%=GtdcG;gf5PN9VqYT~3 z(!7W6?0){a|69-pHINlG7E474bU0P$G!y|*^4d!K_h0KvqdjqPNvA^j0ztN*KDWNM z;78}JqpyNLl#$Kiw(T{VY|bvuw56d)Jk2S2tCv_N8vsLH@R9=GG3PcjG*4r0Adw!w zA3}&)E^?u?$d`G*Oh3&|L9XqUD3wJLDR&Sypd1T?Z}c+(pA$qdIAkboda!Q@#qcs} zLKPcVuYeOehG4lGHYdaGt|n7zod$~>d+lT8k8U$RVp%EMuEr>gC_~9$nOx8mj15Sa1l$WNZ4zqlTBcY9yOR;}w8ncp=g%DBl6b2V|CtT%4B;LoDK z1Y?59xk(yEa!;?nMJj;H;5uIznb)-J z3upu|DlqPfNU@E^uJF1Im}#1C^N1Ccr63Hc20s(Z>x9&)$wbOJRjbKDw4JS19gK}- z6{UJ?SyfP`1{V+LORDGNe!kjL3*7pHs2hFRYq-mpJcBeuXd4S@Y!2w!lWRlPqvh;Y z*9*?P7#R1HdCjj+H(XhlvNo(Bh1ryr(eao(5j5e=T#BvoUv!t0Y`s-Gl_GZ~h!JUfSi?DHyDWhK|ite+L z>)Rg6HZYwjpMt`3SGq{N3}?okbt^=}I;4 zQjVx zvEhcM>Nyvp{dWF0LCWmx7-a0tY}9nlomVCJ!co!N2Q;cx(GVHYuZ@h1jj_#9?U$Oe z>0a1XZQ#rBc*`c2<7=(#a)rko_GHf;UmN?LzjGX*Q`sD;%b+ytP4;y-gv?y@JINb~ zGn&=g!$1%MKH4h-?(#!VDIZO@h-0J=hQicN80!^e10DqV>OA5XiGRkBmyGuH!zwQk z*uD-9#D#&azi5#QD}CCY+=$e>CN#FjTu}M#sD`u{h=e4$=lk^A0~Wze)!02Se8&gT z{rF-KtD;f8A?>vWrrPr;Uk`ubky(o)>Y zfo!cKi0{s4KK+9G&8Y1y-Y(=I4Xx;!kp{@I(rS}A(;FUfG|wEdz{0~OP0lyz=iw+uu4Q0Ep{2R_8a(I#g>Za4z2E+uFTM5L8N^>kwfA70sMC@6A9pX|Gz~ zH$uDlzAn=Nn=Mgxq(}>N5|JV!+y0F-Hns&h8_q{W>BO=72>-Tti9m2OQldMfGpDsd zN;GvAWqmugt*O<$6mCLZ@0pa_%1iKhB zbkXmX($Yhsf^=0PT`bbmL0u=p)DQ%8Qzq0Z<=4w5^9v8k8K`eFJ&1eE(@NyHiBi|@ z3T&0%!9WTY?@aHUCG|o7Ooib4PM1||gJZ`~xCb!lu);|y zuqM~&C-~>{sG3EdVbo_jF_Mvm9|%^zqDY;GUjVbI5WuZ|P&xSH+4p;6-ZEUu959Cl z9NeJ^P+$azb<@*<33FiXvYH$SYp!(|-nBR=w{3criawuC6ToOJ#rGi0aHW!rMTj77 zIHsUgK{%XdrO8e)&GVo(pH|4{^|Hl8gZU;76VaWU$u-I`BfZHI`b|(eyRF{D`cqbn zla6Ip%cQ_8h#F3hh`DX?%ViA9dwEwTmz!2DdcFAs$!-Vr;P81PySqxPzW5dXgT9uF z(`?iI-kP27(4^3O!dGl`-K90}|O;iK34OE*jl|Hr4;HpR037|wD zF4j*~xBjrt8&hL|=*n*$^zFTtvIc~}Hg2exY<2Ey=z}|A!mOg&s|OZSmhG}Wd%H~u z(49M`tyR7Mo`*1Nw@d1{UBb<}NmMD;BogGBW%qVadPufo!?vQ~yDyxU@LUL;#YpEPy4xmjh%TWLc!!^f@Q6v8(319&CPNmNIw!QzOud);nW@RA%T9lhrw>i$tA0T?D!W~)25(`vQzX-O!AYvjk6BdGEckLG$a2BoTgHE$x!|`3%bFjuP5v?rq>}43r~hte z&?ya%x(>SZut7=H;pwTTU`6(H5Vim{YxF1D+!UtoTE0@n+a?IiRCEwcUH}pB+_Qt~atY4#`>Q(insX zgcxd7=WS5kA>L_?yc0NE9H)F0AGi)i3=U9~#qzO2WI);U__!LC_u=y z*qL!bd7v07@L@CYRgs0TMQ-Eat#s?5VzyucA0wDKp>VJ3VQi#0K6G^zy0!VWwG=0K zH|!;Y%Q4G4UU5@=!5LCc(>l!MT?p|7;Zejr6O;izQ`S4wV{n`AZx z$!Nz?l-hbd0fsHOVwaJ^72z3u1;|~cI%5k640DnxKcGDowB4!z3M9S5q)!I-Y_7Rz z-=#f(%xc6aNU4C=bme6;u?kglKoxP(&b|=JRuQ6{`24Yx4o3L}(>KMP;mCro#adx_ z{Z1;=TE$AK^se>hCHYgw_l)z;xbH6H!61)n8GwBcT>}bVT5Ex7Kn1KYcd+R)vX@$B zbl!tW_3ORb7MaFeZ{`|(0YrnxA`aR8A=8h9WZ4YTpeC`cp6Iz$AmN(Y-J^1!@qiNf zPN6hkSpeZ;{BS@(SWh3{ZwxJr9;OmQfS!z@rII2v5NHOVF8;&yLv8%VxEc6D(Wr+& z`o%FaN5*U}IV}N`YYw9`^DDy&1x8BpPyZwIr!suS&Ds|Y}OCyPrYQ6KUrPDDM@p4%R zWhE8HAl&0T(d*e_Yvetm)IzAGHYy?ixTiv0#{GBVcMH}TgviZDxcA{Wttbtmu(Tq< zZLw)YqLBws@)vutd1i9U^o)7UI z<^q90bkh4Je0D{WTCOO}AYIbwa1tS?oLX;(JZjKFQ}KwfFEaDp@Cyk;pM2NlA>i&R zx>0>ItC-Xkp{2qSIOE8mElhM->L8I6PQ?p_i92=SiYp+BR)g$PtqgJgwij)i9oIhg zD|C}B_t*7{BF%ob*Y?N0HL+Gv-rdI2?Ly>tNqJ-t(ffr(I-8gwBikT~bCxSd!tl_pneEWtc}vYq_=roYA}jOt zmB?QlODh?^W`;m8~5ctB-lB7uOW+u|46zJ+TX#9H|=}zbfxMs#U*g zD>IWQcb&Bra`PGbr7}&#Way#$>2u}9A$_1Mx=_neU^ZM)p~0h z5EN*-AaNI!d9qkGi1!d;8H2k-@fKj98lK&yqQP2HJdU!U>g)xY5?cW`uhf$7E^A)+ zWkpnb?5!O*pMW4@?z6nbtHq2qAVQ9tw|v+2R+zKb8{K>y_4Cr1`{I;FZ?Vg~(`fy3 z)5NOTL3EFPY{v4V@~BnnY#5cBxZ0h@-dnwd!<%Yg+by&g(m^d2tGRb@FQ{o7F9^#P zP{@+IeJO}t_)}dR4+lSeiC0KJSaiYSV1nO6=*V~D`_|Y^*mDDqA6umx=Qv*mx^Pg# z#(kND$BOKskLjJ0h`)(m%j{8xf4sm4t|yS3yg7;&<@FR;t9xg_S1Bg)?u~}B%n_p*WSTaxTPTueJMu3U z@$qws7$bugnM{WAL0i5Oi<^2IgCbYeDJdLZ*i+Ft0=I-1BQEgBjgqx1PDSUK)$%@ZO~N`m#unzLoNT zN!low6rK}u@a>W#K>~VmngR{U>gTEa#qU!zC~sE85y`!V?0f~ohs09s+b_9s+1!gu zz+vM=m7ydM#`dhr2Z5NZnP?Q8^Wa5pp%}Bbt(<}WyxVtyg8^EUsde#_{ znxI|7>1#1+8t#?|pjMR88EzKnI25|1j%6Jo}CLM*8dK#Zb6#=D2a z)_V93S$)&db zE3Ho=C}pyCUb`XVsGLi5Yvnjw`6PMS(C%>E*|Wo~;v`h>mc7N*5-h*ZO9_(|<#V-? zb!oA0T6G3C({C`kWi7ukY=qJ)4T@xMn5xUek@V3P1zmgUdIo-K)e$fX47nTc(-wHxb3An8NNhl-T`yxezG2(nym7jRj|}lEoJ3C zKS&#yi=nQ`PT{iD`0n(5I;Ft+36IGl*)%GA%zG1S@j2+0)xLEFHin>iS$NJdI4BM4 z8Xc6r*HCyO$kjZ4qZbWXubOU&DZQZX{rkOE;OmamXI$e=M{h zzU&3FHP1!mC{!$%46$ii8IbHbTp7q~)xB)horFMF%C;O|wwl!8@H67@0OuHFWXiI` z0N1h2#PEt$9rSL-PYxZkC*|pX<^tq!r@OT%^)faQE{cZI=F^J9P&LCh& z@sIze96hyol8~FrtE_OEJTqLIbwIL9J;Zw@~z%csf z1|5nw*r0lReW-*HDm=bzZ(p!sv}&H|R7Ne;TE0ZZK_mEHiL`)wuQ`Zd4Tp&3@>MUT zcSv-@lrdw^t|49XsG}e;HF#)9!ixh*9nY?Wg*7F|s9XI8{QE1>Xac~|my|DslRod~ z2&xrJI{qu4kDe#zyxJ_ItPK~Jyp`iGtw)>}iU+<92h232-+HTxlTcQSN=|xo>(OoLj(#_sRF>ULnZhVy-Ja^%*QgLL$rf%zQ=*XDE=rHo zfe2rEVnnp#=;*{-D37sC*R9PQvvk#$eoT-_KEi#C#IzD-bXOE4CQ-|(8sz!5>A1kx z+muXCD$md+q1`pO&(&EL1U4Glb|qY^mZ5rPp0rGH5%sxrdRjDtNCq)PC%8;P;6c2a zw%k(aWnzD?LrDOoMiZAb9O6Zfgf@~Qs{R)n{}uk5%bk^X{adFm8WBEJxw>7-I`j)d zr+UD|F?p=ZTawRoIpvU8mE!lZJ-5ZlRKmF+LPfJ#Hw6^?@HEhi=LqZoi)cii`k?bl zfd$vMENyDELw3E9Xh(TB`%?qk*@Bj@T4VGx6a~+kCbVRg847P7ao$o{|1sWVaJU@q zg`N?xD?+P-#t~*x)O)Ak`|!|R?WZd`Ph{4nD!`tXGazj9hZXPtVKLXJ+zZHOL~Pul zffy8WL?fy2pnP^2K@LA`zI=)d20|B#l2efvp{Sm@+5BN#oKCi=CLcaPXz-Zx08Yf% zhK~y}6`a*ot?m2z`t2vJa>t~-))U7dNlsV%VUI60exw5YnIRLW-2A>|k-m&4}1+*9P*R*mO5+bqH%);~Ob8rvS@OnC<93fLX z2#yd?Wd2CdBM;P!IT|@EQS3ytGWM_3L3}(S1X~EJCCL-Ra`w_vdjKaK~d5`p<<`*IP;jYkI#oXdEQ z252Ch4xK`gu*mBQQ=*U9?R4!J+iFT>r*v_$nuJ6|90gh~$SHQGlImj=m7(0-0sBZo z0nFUHG0ykQmPLxm>9hOqg(A_7RlKsibDMPK(s8}~wk*DmcibO2;Ca9`l5d;AC}-m7 zXYx12iz}ZsIjPyP;*-qA9V)XEDb}5tXDC+K5O(yi%c2H#I&0(}C%`lZN=5I*v5tk{ zq@=}DEd>Q5a^!VU6 z0eT2jyCZ&ojU=@mGw1so=xNLbUn6loe{qang#Po;!@(3vKuz`p0^*P8nb&DseXr*U zuX6)=oxFMG@Lg{ISdRb>=>Q*U)irTD1X+Cy!1sp>;qUc`qak49(Aw(nO>c~V9yFD%jd=4P}5q0u!)CbWr8*4lJ-$Z>dtO+hj{U$&*)Kd!x2=z%@7XT@+f84gd zg(*Nj^o^k%z>a?-|H29n$kZ`5y_kUQs3^cP68;G?D&F77)|NK9|FrfOm=9r{3fGx_ zV*t!iL?9rfCz$qxe_{TkEc|ycBskn2Y%_`d?M?|3?0WIUeflZys)rGXPA%#PXv{ z9K8Jt*~ZS=-df+<;&0T4GW)-9G{IT|%SwQwK>=i&5uRAXB>FGZzn(sSgZ{53{`cVp z`xod#qxblY;{O7DICX#Xi$V4$;Dalmd0&D*47rC3{Yfhep920rhFXbrujf7>UBV3E zM}I&_c+yxsLo0m~OI@UcfHr9a{Lb8u__Il7;){KEep`rxU+_hT0I$|o%U z49X`A|Et^l$5v03jUO{~)I4GMN$2?2c6zvjr&6bn@lhI|;QulHb)O2K{+i-vwaLd6 zC(S=o{BwZ+PYeoxhnjVh{6o-B$_(|adK=KHp`)6c7KBxRWqNi~rkBKx;e@*mfRLSo#J&m+@ z%#?-pUzq+i?&9|tp9Ts%W>mrWFO2^lHt>6tPY;qGQx;(TO!7a#r)>^W^E)!(;d)%BS!@o+|&h=Ktq?@acl?F@6rs hU-*xgcL30TeIEe_EK-1g$N+yGfcJL@+K0FH{{dOjytV)U diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..128e321078793f41154613544ab7016878d5617e GIT binary patch literal 42437 zcmagEV~j39x2FBHZQHhcw{6?DZQHhX+qP}nHg?>CvkN2FC`PIKdTjbLAX`34>8PU>@i^KjzbxRdnn zB+RF;R|mYL%{iA-2?6_Z zU7!jh>L$P>>K^$-VC%T%aRpM%5D^BJeMvMC>o3P@>OwN1h1~0HkiLZyal6>nXCpXU zA|vKIC@@Rx`$TDd0~n@W!fuGfSr1hzmQ`yWq~XanN(?kcRyPuEDIz*KgU(MqualOl z+=_@y&<$zZq}dBEC{hsI0*=$S>pvAplua&-aa%mWL7=*UfT}-@@j_Ht;JH{)Oxoc! zj|`hX=uZWvajH_HY&+Ntg%azO8S(7ITriEi#GVB6GiJsJ2C+7>xf*q=k?k2_+wg@< zrLWtr*9~{zL4pZOPDKo&Fx&H&n9Vt@r>SiIdBBeXHR#Rq#3|}BsbNL9i##!W3g;ca zEm7DH7{`Vsa3niI_E2NUfk-vT4h8sbC%9C-4{_w!sF2How@~5F^|w)J+V;{nTjry zxbfjP18pM4NSH|&=62wWSMu~AoJD2JQ(!w~R8;EtEe4LZgdq;0JCPRKNkb z|4+q^#QbsUKV+RG8#wC_iq?c+y3ARM7W?$ReEoBD9JI;1O{Q@ZjEIv~e+#W0gzQOC zE2M}>mKN%qu`E6|=QwD$`i9FsSxe6u5+qSJW|>C&z4+i$vZ*ifeOMIEz{9N(<@5%u zHU(khTf6NFod9N?L(#J){c%UN;%WFb+}H1JZP5`ozj1t^Gl%%AwRx2^!BwchFMwJO z>j)lHhpy{d+9m>}bSXMT)464g=JY@mXYtLEI$bDe&c&;-Dr3RTowZjsf;4~J+&bvd zB1hVO=E=Zxsz%`V5?RzT*3ytwP6Z$0Uv5UMyA<_SNTHD_@?3pLboUCx?J;{b$ncaB zAJ3p-s%jf&eXSSD>GUQOPG=S0gmP@1yswUphw6AS)>U~w;xifo9~;?!_}3-;uC;$T zAozDC)|axyy*?gEc^&DOlNL>`K@bIml~+c3U}w`s}ngJPJQQra!J< z7EYZxt$HoA3lI&vO}PD0z~bZL1|~dhlY%zGIDzb4}z5eUA zEN3yJ9;b@VLUS`u02WueBo_1QA@c(2ZeZJf*tKG=@V|2)TO?2B$NBAs84~~?vH$>3 z|2qeaZ7fagT%1jv+)SPRI}6Y_&$Rw){U2Fy&1>zv#hJACO;bM^Uf_ndGCO0pl2+-X z%BMZyvDV<+%F@xXL6?<8m}wlfoTS3jS@C(hh5#6jZu}={b~{u~9Vr3~m%q1TL0C6@ zwbzIKXqA{=75V1ydHSq>H?d&A=%G&68j5=DyBg)wLv^f`W=q1{FrF^0IfUS$J_Xwq z<=T=EnlbUBIQc+{;y=S@df3%%*QQCdYx?-voihLYUE5P9m6%>r&CwFlHzqEAMh2L~ z31d2})$Ev$J=3H*mVf-L){#`4I{mA?vCoqa_O+>EN`eyr$yjXB*urzD89S0)85V~n z=S%uP2CFZ9No!DFI&-N;Utg4U9q!zM4;(x10W2PX2h~IgeXd$V);1bWVrA+VHP$VF z+n?{$tJ>cX8ptEojgLonWkt8>(qY@yBBMNPr`l$ueF{lV=hQE)#!8YnlT`_jplajf zQlg8tY5}gk`YX%kpFkIUEn&1dxGa8{mUo43n>%wto6Ua?5zp*hexfvE+w{20qx7-R3IO(1)UnqZec6Dk7!00=$_SwuKdc?*pmp5OB%}Kr zv6TmG>6YHs)%{#~`|UjQ@#gmS$=c8D$=2}+OhqRzU9_oVm@irP zsN_g;aTJxwYTL3Slh>C6#M>`fw_0C^;9Ib>vpcP-JXIsA4{p(4&Q`pJg2L&3ooU(@ zy{x^sTs-?em-NeAb%9pzugD~fRr$*)X0S2&bS2ET(U9x4>H=4m;JWum_LUPRZFNik zHX6}j%|cD%t0m5{=nw?r;N#&Dp-gJwTHE0Y-@@Bf!5Ne=(Rl zYI2$Gm&|5BV%Fi}&w&iwaONeq|7koGj)q44#WL6*@T86w3%f=}-tS1_R&bZt@nH(l zJSZ3Ru6p-KU1_o*3lxn4UgxSY;~(UgoDWy8B#a-%nb&f%hzM`om9$45ZxX^fAC7kg zXI6%}5hsTrt(|F#3 zwK}d2^Q6$B8)#-72?Xa?<8fPw1X!qFFsF|kW~lKy?JsC6rHqY}RQ{1cHCw5K7XmK5 zYA$c2j3}k57pC#Iif5=o-b&x7PPovjt@26_f)w^|o_Rs6@`Od7l-oSD*0g1-7X>$v z+W+PA`0=_~YV}GWb8Yj;gU$Y7BJ#52ksF7sBoLrQ&3fnMHE}Z@s)%(yjajE!ZQ#_8 z(rx*MhT3MeI1HHO;J_%0$>jzt?Q-Ydp1wN$BrB;$9B)`o*~O429jfm?(L=N&ktr2aQx!?dT` zN`c*(d{Ab#iLRPYefN)&wxXgHSRMaEnJmz-g&`C)Y&1(<8PUKU(-PH2fJD~cEw1!t zW>K~;Dfl1v+nj^IfGkscH(8psqxaTPk8^a^EA9-BT||(r`T>Jf)heQ{ydW&okScM8 za-?&J0fuIA+zS&=;kPjb3A@d6Y&vDZ=wT$yso*ph4@HCQI57Q5YC%Ni|eq^0{9ar>J1HzAV7w^Ezy`tt?kJq zEXP8M{&Y=AzAJ#dkyNAHx*P2zzTrQ{k;g8Ld3#t;85bbqBN-4;9P=+pU^+`iM)4}<}I#IYg z3GWQ*EF|oN&Ot5rG93anmX(i9m<%SkJ86qCiLIS&D71WVc_4LE#+y=3{D=Gi|nPm~W$T_IPrnphZ4+w{Wev!O(ZEB;u zkM35Kn{VW={4TT*xPQk}Itl~Bub@0BdQs=DTUKq85<+;A8nKB+2QR}l<8dV2rk~(s zT0ZJ~d8@Ad;a#^jT|tegDXlY_&3N5Z#D^Hzp_BmE=MGA+kUx$KcHsyHP1IbqHK7JgoShdU6m-E}MsvS$ zNTIt&g(*D-Z25^yUohuR52^o5JBY09%uf19qN63s&y?+q{{TqtkJ7e`qR`fE$fL)d z1n-&!y7JYwPN;@sZM4;h`kkaC4sfddJLtah&D3pyNUFJKTAOx@N@9;p*~@f!%nHKu$FbTo8!X&$u+$eJ#JN$=>q)*w3i0-${nfgS zXR|?wCkhQqVR5hf(;QezAMNfsVjCbY5FdYi4-?WY>cW@dvj?iC@D` zZ+MiT##BD*afvwl)4ax5N@6EK3IX=}k-0g5i6$X_rFPL)VjED;M*6Qu-lmW$%8ZZ# zsh73)OtXEl9^!=xtPd=?p|NWb1S0(6?F^mUQBr12{X)%0+9x|biPaK(;vgnoBiVM; ze%R>-UsB8C(3t6a`4Tyo#8E!BLl#3D>+&PM-$05ze8vRl+sG}ReUNiR9_Oz5e9fp> zX!L$mid_{R5n+~FNs1!gt|QvwRkwCp&HnBOK>xQbDfcWCGqZKR6&D41j}Hygjm?wfZS|IN9=vxOhYz17fo z29W+`mmvA^;A{VoFEVd6(2W(geLm)&Lnmf85wVr<*aLYZ!+xPm`dvljTa3x(=Nm?5 zumCPC6l+wz9KO;Ex{3y~o=dW3>22$%j?6P}3q#0;`4mR@Axv1=vIO&_2WUqvPKmzd zI1$u9wTTVm8tK@VL?*~<8-U?=0t;GMjVSWV;jrY^>L1X5Nr!M4=w<;kAaZ)3RDusa z?c_D_(6cTUHI9dxX50dqO%|(-Imb}yzfwBE%3qSUV{NXYaYaSp*Vo??zmKan2rCb4 z9((;*ITH1+FN{pW+BAHC*{32uu!5mXPN%bjU+!(!`2C;M_5&Uvs)d1yiuqtpXXO*U zXP||N)dl)W`U=cZIm$Ry#9F-<$1-yD$%?k-lz93c_=8xFJzrX6jHUUS7vBWM_$nr7 z?~mD7CBwvb3#t&p@FD^?CSL+lJ9i_Oy>J>iJVPS$9%ZrU-uN7mQyQALO25$Ow9-ht^r&{C@*{B}1j7+k?ngnC2)+>esdc zeB4nK9N5jwLq3u^H87?Kkd3clJP;y_Jj=s$h=)7$cx#Ht$>=524MGI>ZTVlJg6~UI zY_8n~_PE^@mBf(tV{fdp1Tmw=*=vP3fJs&j4_nz5%U?ksk1$$n z;ZZ0QB_mf0^;Un<&r8b%qP_tTbSZP@>1y^X3K~V1_1FLUb{k_FrqY!g(vD)y(Go4< z&j;=|>U`2LJiDBMu8?z%&YPdC?U^;SmX-qJ(ZZyCCpfpyJ%g!=J%+%_FEul)%;2zI zVqmRo4W4CJtPY<~GS;ML?`V|n7SzMn<>Sp%kt{{#cG4^pIckO<;0U5uaY73nj692G zSS5y;Fzklk1fd=p=q?qFD5-eez&3DL(fvA#imuNiGylraa+w@5T(&X$A0=~Kyq`R& zd6OPxS1t9GaI~#0CA$wm^-y|qmcZUG$KqqY_EzXVkQTqKyBRL_7O-yAyc~^jZBed~ z`gzP<@-Lce&Nx;G;K6<();!`!B0phyP8wo@6q1dBcg*4+Bg?D9Uv3aCP8jQM??kX( z8nRj8vCG%k`e4l|Rm5n|4Xi`bD5>NcNMimi_x7Tqg>a*(q`faQfM)lC@O@Xv1z8J6 zl85~y9jVN8@J(Zb=YM;T|NVS@IIBCT8_1Xi5q(R6=Z=@>+xE-O)9dkeck)cr%j?dL zI~<(n2P*4{&Z@S-FOe)f`bp!%%k3B0X-%3_E-OE>Fbh!fN@GkBh%z0=1Yhv=K}}yB z$c6A&L#Z^J3uxhYJw)VY61|(ILJ)YBaSN3JZ5mHcVt<3fRTYy7 zzTncV1O)Zk($7e37h&~ym-*-n5B(!WJ-p!yFv%B2=Cu(=sy-TOqQLW7gBjv^gMzJg z{T4P#?Azlc^bxTCYmyz?*+$W}6i>0ot;83WP^1-rwB9~xUm^vNU7K|*i(7D-eMxtw9(-WwMhj{TCSp=r9=8y07I|I>czux3|U?y9R zl4f0+|F0#=7TNr2XR(LNSgtQ8HfpP?iNukS!2m zL7HeS+YKYI480WkrN+zj3X|cGWBYleiD(&%DZxI8NDrPFGo*b>SSB{i&N$>Ux_%P> zj2n{4ARao!V*am0#{AmS7Q$NkMc2j5%*573y*Y8tqzNX!5AVD9Jhh)FlwUI&zuHbw zY~@D0WiMxpa)xSLp@W6Gc5K;B+fs5gjH}o*JJ0=wTb1k@;-=a|5EneN~c zUZmcYg};&25S~#V+&=|e?VYXUt6#w2z#=_;X<6HxmIZEK&kRPRgv`Ss#HgT%$!ks5 zo)q*Eoew83n3;_3C=aTruTBI(iS&9pfe$5t)YjIvJ?z&U&Ap3H(Fx{6Iaya; zp4?ntwpOpi*`W~_*yyd#fMOf1xn$Y=HmBdqiMkr0A(j4CviYS4AI}A#>(9)<>z^lC z>qxM2oMTCS-5Fxw&7SpZg*(0AYJri4BXSsyp(2=YKf|Dv$cVxVeAGROIOcN10fZnY ziffYYvtAD5ngom?$7^K7mI@0_dY4$!4IXh5_(eW9dcy?* zHkZtl0H-;pPp)YHcD7eT%0h0Icx$)$VV9r`mvCQyXiwgGecGn(d>qny8Fj>$XQGLo zojU!3$`!8J{&BNszw2L8yu%lW}g zBtBsrZ?q66>7Ie#`1@xMKY3^G8$Y@T?HC@u3=vw}EkdKPuq)Kw zQ9HFQSfI`dQ!#d zu06-gXT`gQZxipkNgO%aMNRB&_QmbB#V*A@;%LN=!nzk+-6A7tb6MUJr!!$N%4x@L zNXo>woxm7nR$3a>-k9&vzdc1Ka)l^GPi^84-P9m@u^AWdg(=u-bnNq}ET&nO@7*^y zGWGJ_DUz(9NxrXYlLzboko~)ED?#-8FaiNIah^w_LVerd3>n_;-9Q53(Sj&do%K~9 zJsY#_mMt!m2!9=&e?{;VRx$S7AvR}RE7VdeHfxVj|D2>{-{j$Ro`(m#ukBw+v*ZfL z|B!uuaPr9?Tf2hM&_~+~93)A3jmwIkP8W6PlSJY|VT#$Gy^1M2cUJdbjGMXGZQa|E zL&n>?Kkiy+b(cS=?j4?7<&SUpQY0q~zCF{ELEN_mtqH^Ec`I@*xUW_ek?rXAtdd80 zYeMNuInCdfkCXSxBSN@nI9}?O;UC`#eT{020by-}yjA7V=A&S=`fd3i4x$RX0XM!f z2c6}VbE-^w?PdJLp-e*%buLZ#9F+fEUc!K%Ux7!5T7E7b8~6?C|D6b2o85YTy~=;l zg}DG)-1_}!V3ygXgYK3 zz9DgB)pR0tTL||owE}gN>CToawE?P}IEIRYTJ0RHUSz6}8Pi(~kyVD?C9=Y@>?7Y- zrb%)9U)E;A9xftid^sx}=_zq~eJlI|VqyOu*g%4+6T0d(M0$)nuU_Z?KYC$L4_5** zTHia@H`={KULR2q2OyyvjeyY6z`49>=*P14ysEuzXmMKIx3P&thL|m?`{{s*h}BCj z=CuXxkwF602)r~-QSMVM&WgvP!J)+aKR=ukA`5U)+hT1khnNK|2DvcxdB zwUXrD6|CP^QZt^}6SaGv);O+lOD2*IMnh)TbZU~f3>zA5K{+pcCLj0NJ8Urjih3c4 zZ5wGX`D$L|j?AH*?jc@(2in=n&i&1h+~!}Yt)AK%C#pW*n3JnKk`*Na`G^)}@w!>~ zEJ`JUio~CIGS2*s78fC&5`cMQHoq~~i4XfPMK0H&KJ{}4f;+6N$dGb-0dw$hoX>4AV}^Zaa5U z@`!U|fMO5#x40jk?Q&jOjFo_p+7xS+MvEGf%Z+IwH>8XGjT5(Z`OF*5h(qWyc z6SREUb6M3-$9euwGQQ8HBk&c0Yv_mI9nSy1zLYTfi|BE!z7xGAhUmYh&!j_Og`r(cHc!lXt~(`2Xi3z=mwKozO=9g_rQ6UIbrkY- zoBN$Ww(+pju@){=Na4-;x9@!h_5(a^kB&$kEIk|cr@Ke<=a_V2WJL5q46R0>9<52i zV1f+Wz+s3C<%`5>9E=5=HA_r#XsuX_TK1P?TVwecF@shDCbO7das!Kef&n{o1UMyx zLuM!*=IvnQl_81LxTYp$c0bEfyP)3-qOTE7YBW~9V!1MGrP9W>%CQAlk3&k^f~*YL zfs7N(Y>+mfCKkD7M0H*wsHBFno~DJx=TqzC_1?6}`<;G9bw{D&o~@%wHi@)gQQQ6- zg2pj^K&EET(C$nV_(>s5-GQUP4Pslc4asUoWVnV6G(gq6+aUViZr6fFp;#tLUM>g- zOP-OGt7l8#WZG6#@gEeUnTJo$EDwS6onzm;^7|?rj&{9!0u(3Y>(-C7h5iQ6j=+Uq8stvPN zNT+=EJP)vUIjX}eBYe387{vz;6+m%h%iuWBQ&GAeo=e1iUd15H89!Z>rA}Qe-=RTq&NvcM zR^MOp*K^Wz=`Tz(pr=W$!XM4>2VwjOstMxAbNGT^dB>UEl50@LATckD*Wtu6>;Ft& zQUj26+5b3U_WxP`+XB zG;~ySQnB&s*q6`~`V@18n3_xB3q3T1RM8VHgp zF#pG+;I<*%OBD_P02TuPF#o%0SsA(+>KohJ*q9o-SlZh;|38j*tz&Dy#fkiLqd(x> zC!xJ2(fN6RXoEg(w$t^L#2T>!hQ1?IV#}pQBh9BMcS!Tnw>Nz!TC9glA-6@V03vGZ zWX73;#g{`=!eh+f=P^LO!75h=zrM9lh4DL_CfBxgonyMBN~}SD%q(rn6L#A~3vu|Q zbqvN>zxI_ibW4hqOgyW^F4>+U)JSZMab*-o*Vw{(Gf7KAYm{h;@l(&MvX%<%g1%9RV!8>ILE@7>2rWD_-^ZV4mKY=c-_eN3> zc4rU7ReFg2cBb=CI6eU0fLjMM@R(D#!q}jbwBkSPrg_ZEeu~@|5Ldu5EAF7}0!H3* zV9kQv`7rZFGTkb^H#clVD=?T!Sr`haP+swi(W5oRdNIvI!^((utmx$)+TgP{!npdY zhw#9&WZb`v!ILW{eOGKX0mexeTkF{f+;w_Lon{bCDVI(|T1E&8Cz)|QqZYb`C-kPU zpZGv=enUJJ?9_Sy1QlQaWG*!-yH8FS9~1wCIjm+7An>S|26AU@b=n~D&DXmj>SVlqLsK~OcptN|m*I+5CJQ34SB zlK#oSjQFMYh(duVlT~A|xrJgTPQ_2&EYb#vnWuZ@&Eh_`HdC=dzYrws5+9;SEOP1(RwWH)DF;%( zbP6z0-hjW%wjAgI(G+>I>MZa=-{os{_oGdF)-(jb#UM7Y01l`x1A{9L7c0_GYgFqW2IUpr9@PU;;Pmi?dNs>5t$P>zJW-{YkA(9$eCDW;!a^vO}B5$c?|7bv`tN~2{0Hg$F7`u7)B@vmysjgZSTznLR8qV_oj zu*G~BSX77R)L{s{x$0n@cePZ+Od<wr5Ixb8)(cT=A|cl5 z4y||FYu@*?p9wLL0`X*hl_|iT6Ya_boE{_GKwxJh0Ahs{vO)+-*)kp-N*ymghV6Q(Inw&h7&@X%Ib z-`yG)e-%hRx^4YnU<+wA1iwGis~OxL)Lf))-ZLb^OPNmMT!-PYtp}@R_S%$zqKS4xvoJ4wRPAfv&ph@p^y$$C$tr zQ)%Z6d5B3qlf2Z3Y=|-sE6o%KSBZ3(+Fw-wUFySk(bQBCf8h~_xw-gRpXiymImpYG zyu5Y#jeY>+P+Ol}aT}IhisH6F2q+O|)2I1nO~w6Jy;~6S`NguUpIr)m|GZdebkS5| z$#F$DQU_ZILAo!M0VEym{O?+Sg_jQsWv1;&&fi~g=12-NZqTbCgbPk9jw1wl`8bGp zV?R)6G`Ebu)V$vsJI92FU)-JM!nq*P+-dcKsJH=Is_Ms6li(G$Bap!#AL<5))HWhu zEtFGt`E2-K28Se1Or@1#{|+2=YatMJJ1vFtk!%pK1D3%}%vt~rL%E1ysOkw$uNG}V zpj&C=EwThN0IvNPJkdM2Im$xHYBM6J3XFq+$~LiaR9@E~(5_Y%#RvKmT?3mhhI$b{ zOk>DoO#e#iM->QMK`L4s4O$H^n#U1YwwCK1g>-*jNPItQ$XX=pXto4CQ_l6vc2j8; z0VoYsR&qMhG9-3bZC-K`=>Hgh0-@rNIUOv8hZTuB)=5$6PyZxF`6CWL@>|={biU$p zpc$-C-!2}2VLMdJVUk{!8!~ggPKNEu-2G6#^}lo;+JazZ19dH z9<3~wVFeOL3EUTG3J?HB(I__?I$c?8wCof6=^=-si|!l(TzEPBT!g&5`wrC0%6H0- zP&(4NmwHq0ZH}*;5hOR4lDOVr-1+S%>%})DjI9mes_6{F&vw82W3NPAk%lQyy5cO| z?|QYehSG&BuJW%RuPmyZ@Gbe6x3OwTg9zl9!7T`r|O+`m-X3*$&p)c#mxt@^5{Bi0HBCh#o)37K|ghtlDhH zLB1{&tiwFMV~DG19s#y2EI7#=h>&E>ps?_x-K}rYU|5r`F7k459d<$G9Y>8cK^@=8 z6qy;%Yc3>pMasI3Hd*h1+@RQfZNJ8j+*ZQ9<*GGjDV)yBR6*))=a)9_=^8QpUuyA5 zLHMbb#@}AojI?J9^+R2dp5nk4?Guh`oa|5zRV(fEuD6ZC;}sK}NxkdlnGy!SRvQ2N z$2;VVc7g)fQTe7oOY*&`gHflqAMuwmilGJc5~jorkk8w*=e@Zrg0t=)Yucz8gb-h- zA&zC_w+;HPF?qVWFtFkfyL1Eo@kF=gvh(51V2bh!-ML)E*mNzqGBp9UX(C=D-J90QkcX-hp@y~zH&NQvNxEM*FF8HD zJ1YwO-lGTxx{d7oj%87sVP%BRM1$Xx zQy=7x?k8KuMcrB{?Mwlas&iMcJIU*MtZAGDm?7nxxBNmOPo-T zsk@zZ8*l{KiC>%oZs{(bJDcalb519;SD>#JkWk9cA{gYlx9B0xpRIFd6zpdmv(q{Y zti?!gN!y#3tbp6;Vuz_gRY`*jIq7DWo0&Mr)OC@?*2=UFZ9V$!X8vQ8t@VL?GTX1Q z^3&+0?v=a>w0ba^QAQa*?%E`D*pDGEJ}h^uS+mon4`a>WdzlgNGw{sUMt2-#c}kSj z_yTEm6!cLa7XRAQpUv;s_F0jaNzerDmOjLVPPhEYP_R6L2s|h&7Ft@Jwuz|zbFb&m z;mv0}=FFrRU(C_ROMA9Ktytw01CxR^bu>WH3wATE6i)>X_DiaS$-Bm?W?(CJD(ieT zQrjVU8-@;>lRiI#+b}22D|+6SbE9M(q_y5&SI5ulb-n}sB&z+T5!6LYsVGGI4H6XL za5PIPM2#b6?(}Z?eLkF?f=2ai)y$XR#%1eoM|dZI+~1Ts|F}giL_mQUul<`9SjnQ2QFYpVoD|era2t3sJ_=ar1}O zj*T0^hZ$Q3KXzAV*W2aujGn#X&i{Y@mj25f|A#O_)usvns5t`tk30Uq#x!$77gKjb z&;K6L{x9zMKOI}=EwQAZS$!sH{gohiu&KBv_bF=M1hP%V!nRaO?mu-fxX8lD=#Yc} z%FOJCpO>q2PY?-esT}E>p+8iSpk`iPUY*`~H+c&v4X-ZpGLM`Sry!0jUOlh+Z^cX3 zZgUSFsnm*R#1mGkla_Qh$zDVc6XI+&G#=E&Z*sem_v}~k(!wdq5pJh zuSD}&_QG;!wKd6|k}9Q^VKk#nne~^AvL>wtNS0y9dcD2LE49+R6fz@K(W@G>PD?I9 zS^stD0IwZV&fTb}Q823HMAO!ZR+PRwRWxE|?nxF;dde6W2jQW1%Oq>3)_V6JDOQ-& z7y+rouc-L!U@*MA{_bAD9z&a{!Vl9Bb}3Y7)s8geRLPK|N!17^#}{r?>ZMXgi8RtW z85pFJ5t^cvKut2&dQRB3C=>&EgA=HwEYRadYiucOQpt6dQPLVFWHNf5|zxrYf z$5lw~S zAM?!pr4GoHqX6Mr2;zCbAyT1}`-~01n{dzaH(1vlgaI!Lj!2dy~$5kjUJ~XAs%Jm0FZ(AUIX{hL+4Wn#q zhz3`@&8Y|%FE6i`i|2!{6JM_fj4ykqV2UGmia_5FL)YKm2@6j*Cl5Yei~{3q_-PeG z*N(rlbKqUA5`Ncxt@83k)IFBp>O-+s3O2pVO9i`IIN0}h1dSg9vPK_2AXJR>9skJ8{L>O(bPLM zJ1QgJx^NQgONvdU2p@`B0Ld11L~H=S0i~#@A|l7Kw1W{-c%gpLEnp8E{J9D}HeR@CgX}CcZeEGiM0wGrZ^SbHhfz1`B!xo-r z+ITf#!goNVcNseO;`zBnXU|-p9p_J=Zu&fV0*ZuvRn|8U%T(3JYc95u zV9PfO#mc~Ujss$#lMpd>nAjt;KH9=HFRUJzS9G}CP&M#d4G1hQkWoAN9=mx-D?9_7KGexvX0(EDWywpcSw_W(SDD%w-i`^?SCW+iR!g_wTP z8KpmU`9&aC2Ma5A4y-Y=;kSqcxIm2&yE?ktI(|16{~|1umKGmsAx&$k(SY`W6x9M1 zX`qR6%7IMOgS4P`l#apdd03#w_LqE<0)JGrN-wOBI&}}*M>uS$?TORWkw>X zTjrutUBX&IumET%S%^6lfv4k(cuCa@Jm6>qm5i4IVrZWx1yUpq2!S)tcMc}XU*ANn z#sO|be&OBR*3L`llq&V^i>)OYlP zkJi#ZolW~h%N*JAcsw72YzjP8FY)**motgf?n-3Ak!rl zL^Blo&iLW96=Ow$7RkbN^7eY|alSx-!nQ(C5taf~3u6Lz?(?CNB?h|}fq~a!L+Tv% z?4=q)0rCwq1J-u{sVV)ccJcX($;K=a)F01ZmJggUUt#3%iQ6cf2kll#j2k zFXz|4u?NBcCTUIkOlkXQ>9-(9;RWmtZxFNZ$LRljQq%jen!{1CWaAQs-q8?7z9z2Z zB^II8XG8HKk9mO$Jmu1|XcPu~+$$+7j4T-4-P>X9kX`oH4T%R(m$0Ouc#Y!70d+zr zRsbBv1x;1UEt#bSh?83|yKanA5T6K%1@sAh1D9$7!WQvu74k=6SbhZd(an2_!1{~p zjSc8ykF+L+_?CR?RpGbmWdb!}R5%enSr;5sDJ(sp1N{&*U(FWP)ECavjt`ddaC#3w zTm-OXmq3XI0#CSOvg*oPxja;7eC0ChPdRzGq0dS{$x@*mM9V2ulywsLyR}03N_Thj&~?wYHC*_M>oAfzhNsG|VQ& zCgDmet7-&=+oe&hG%3mA^$%RJg4>f~ljZ{gL5TkmL6sah$9E%AI-!UFwe=fwh0ZOd^_O>An<4upV_eQ@RR#2g1x)h3}9iY89 zz<1L~R)bL}x!H2fk-`iLxiavHfRlmzI}#bii1b2Lfu(8Gi)~U|+P*dtV2NO%y|KDx zUp|?H6Qu#%I+f&KG2YFOnOb#?Dglw9{uQx(8_~e#ycJOnTjhzQHj2HE*;=fOjM1Na zkCV2ZaTO+zXB|#1#J14dn2a16ODv-i4AT?N#-$w6yn?UcN}U{fxFf{_^;9(7t6IJzXZg-J2Er=4D8gF1;E zgQJBKaj}6`$dMW4$_ETYB|?MlatNBdx`1Xu%hIrdL4O7VkZTQQ9n`}zqGaScC{B3d zt(O@p1#=g|q5zkxG*X5wTqFvId8wetv_?J#g{Fd4BiZjunXo~7Rw{Wi=n{QuJ?UU7 zga*?xJXvKKeGnNY2j0^Jhk*NUu`1^Ij}MNOqZzcSTAH;B)u#4l!L6yO4quoK@fWgF=TpKU-=s%d026H2~G$?`kSV7-Rf(4Gi|1?-ucRbrh$G$FC8 zwOG>rzrt%qx0`%!s3PuS;xZh7LH{(_p{;w zZ=d_(1)Sf=WOn#?RfDx+4tEVDwoTKG>V9CgeY+>3h{;l28+r30cA8+V8kQLmlgPZD zlAshzG2USjG>S4REflEdKFR7hc(cWvcx8g{?Yx5C%BFZ>%00q!6v(ngANUhfT);7OY5! z*I*6mlUx?5>XRY(8_!W)rh7-vbYIrU?bNZH?5As-; zAO}#AG`UlKofccd1JEw3I0WJ=JdssR0b7k zw7~T_K-UZ*VNwDZo2xqLTGQ$}ACk;mmu9qOi8Cb^EZX)E{ekIX;eqVW5pCwY+r+qa zN5{0UWs-VWO~kokpG9$C=ufv5A!-@*+Bs=^#4G3?5r83U!(1GFpy#LuI%nYTT zVd~TYt;&==K?A(19Yb%6(5#3mRj`x{Q7`7ynHHC(eQ(7AGTgK6{5bUk#nFtW;A1nx zFT{I3y`=S-D?Hr~Mrc0Xa-N=tw~3~`J!rNEFnCg;+U^xep}IEn?;G#~B)HigHtRNK zY1&kENdbAUQErIUWHk?TB#OY&06?-aBPf5VOfJ|imBuAzlT5V7rCm^Op&)76p=>py zxx?ZHfIE~|Ni^NiXak3EV3qDYHE)f9zOBRR4}&C?sgwsBJ357&SS4sK=R;ixAN97d zNJl_7jn2MrQL+u5h&dnGgf+s$3Sr{GQnGRDLdeRD)E7r$|C>7ZY3T$T7mU{WwN0 zSrT>xq|q1eTa*0oD0zB@EM>pfc#gg%2(x7=|Zn?5zyD+H(z<3 zsbij*ScK!An2WV^Vu$9zsRFLx5)QY|$uvBAOL?mZVVfWW{}E3S4WL6PH><5Wsk6la zSxVnm(QSo#-8$mq|MryAEKzC92#ABvzMw}Drp;Yy^-9%c0?}fnV?Ko_1X{#m?3$0A z)r~!hWB+uM$;0WS*ld+DP8|dHIdgL`!@W)CZq5Fgv!f@3&tzNGa5o;KWP2t|r!7sE zhlcQbI{vvMoyF0(nkP#61)hHc0=LW4uTO@j$J>k3{OZk+YSTMQ({2u7#O1t3k=k+k z8*9>y0UlkTUPyaBzi|j(7poi3bhX-Pt{L0N==>r&8=k!yUKutqTVhA>&Xw3PcAoTa zdZcca>*b=eD49u;;{g|gwKjG&C(-^6>GxWbc3`QE%^MB2+TKQKzZs9|I{yQUw2R5+ zcJ0}(cjD_oL+=Y>sCCZkK~o0{X&58P;9Qh|xLJQY9NiAf|dOy#SoGWNzj&~9bc1PqUcA$U1xOnET>h&eA3EIARg zM}SGqC9=XJhg3lYx+!N=@q(uDBAPb_-CEw1FN_D61*Y_{j;dm2tJ!NRxg7MH{2i0= zyL!Ui^muyIT@Xks3c|88g?YokGXYDsp$fmO7Y`_vQIw>L%mkx+7U)LX#-YTMcP=xF zsj^(5VCf*fq;|&MC}+^PP(X=MPTsQE9NFd)XeI`>`e$m#hT4Pcg&i(554+mapab0f zIEBr3QbmAjwuS&grd-n`JL&5i(?{5;nU>M?b){;{$i2oQH@7VO!j*3{Ws@h+ShSWC%Y!X1mzv^=L z$Gc@nX-{xt`a*U%S{LdD58x)%iJs=gEND)sq3^yAbqogEjNpofuXKv~PKK$><{`06 z9@EXhRo&1--Qkg}+Md$@H-zdJvLq6A16$mZR@=j-T{U&n6gajt(z;32scBx8sw(m< zrlN{$v%k$b`3gjzvAI@E7MUM<<4+8>owWXeZ%yo9(P?UKVtfYXcIisiqRj76B6tME zyTj+n!*m^p3)dGLpC`d4>0@ z&1n+LVC7N}EHGxcF4tauT#p{IjI_lN1$SG~HE$-vZ3Dh6NyI=(-9AzPH$Kbh$XKtj z80J1%z;{`!x8GbAY+2EbF*HJI{cBdFMSmc!9SOW}ruo>tGA#ms&H}ZV2wUO{ufM$+ zMOVk~uD)#P=zbo`vmx$XD~VwzdI&%Ta{EoL>x=00ieMX+QujH$LewR$u19p` z04AVeTV76+iL*sh{Piv%o4O_5IjNOV2z~NrhP64I4IZ20b7%`&QN#2!OBTV7)S!p{ zQiHqNXA}&eo<5i09eu+977MxhLKM%xMrp^I=b|ylp*|w z9loi0X_ROAYzCX!bN!uV4G;D?ZhGkE*n_vmcugE2<3HUI;!*^Yzb46F`)2gM9kypy zeY#N7TkqPnUHd*sb@NWy5q8k053B{#u^bt%c@d^6xl?#q#{88JN;D=0((Cg9yl6na z5|j_G_0}cDJO>PmI6Q2$E%iRAYw~EEYrBtAZ*$JXyCK-p0@gEkOr_ym%+&dYjqYXC z3+K=}-sHqH(8p8V6Yf{lxxSoXsWsKYUv+bvZ$I4M-+yjGJKR1iW!0qNN4&R|5QzUS zc+qei`i7dXrfrMXwctrd8+d9oXved65(^#Z#mh_|e+Kta))eFHI`$kVRguayJ4UPX$NUs^ z6NjR|qN6W%>)Toc9kZ>u?j~QXDew$K4Sjj4nVTvH9**+BHV)w;H#!KzMLdrwo{%|F zHdh(u;@WDiWip_u*Pi}n4|&d#v{t2#CUK;yaVm&M=mKng`Z~|MbU?@DAI{eBkp2vf z_xlwt1fKVzBp?Nl@KkS;BR)_|EZ~VXS%XJ{-DECtcO*9O&<&o_Hi0?pE;%=({3O%y zBViNO_xW@de>umC1y8#zIAImSJ03(yPqNtrXWR&|?MYKTDxlzsb`C8AeqfLEvEJZu zIH0bN@Yg!ipc!w*fo7atbQ;|2$E}l1PD3#@VS4gqcKi3b5}g^~hA`12m8HqXjFH6H zW*}`Mw>=9zNaugr`|{>Cjx5jr^(m^Mi7?=SV5y^LVuIUr4^7cBZIRTHl-&+l0Th9v zSQZGtC_og~>+in#uFRKL0i@)scM#ncN#uRJeCO{mErJDrt!u1!6>(nz($vMv*%^ZL zLw9CaPdZ8h47HzM9Xx-L4HCq8xMC+?e0{LrstbjW8NtBszncmkgioaFSkRrX zsBDV`2ubUmBNI`|4j0&n;2&#Q_n+^boo%(Mrqh49m1MnKwe0lpz6W2!!1O|!1{76` zt@ZGU+{L7Zb9_fv13{O{d$I6VCstOqL^3jIPB-V<5a1;SQfLc@hKBHSxr*X(FlhOaeMATQYt zTFi;IjNlN=xhwPpa4tXf3?SMJ{j5M^1pHOMP`(06D}-z|y*>os2FF7?fMU@{!b6ig zjFn^X0quLDHnQ=s4JszNrHwHIYLR0eRm!VZ38YhHo*c)SLOZqlOPEK9@=stI-` ziF+q4Sxy_z3B8&eT>x(@0CT`yG!)Y8`tq&1`I%Pql4)0}U2jPK zQZ?%_I!ejPp1eGN@$!85?a85~%L9PqGGp(oSC^0eOOY?EZ%#VWWFum9xv@5vUV#T= z>VJBAaymRcIDdJ1JUlu$JKKA9(AwR@;{4?4NdX&CoPqv=+7nBlb+X1k|06$A{JwXW zL$gm{3V}O0=sHNW!O4-}mT4m(nn@iIuwhlCD6)4iAg_DTwOpbx3$#e)4lz288nR)l z{T4KiS|??>$H~Y+VK2j)U@U#P-jTYd~`9LB;!O2Dx<$DN8z6>LZ5mp-)n27;TZ# zBxQBM`&J!j^+!l&=4J(_M!16406MYSL8oN`r&cQliLnDB4Kd&gRqhLvZ$Pz6!K!a( z^Bd9yVWg=*4<8jM-GScuy&I5-xgY`zP_#|b&2{2c?RL^j}0$GkB|R2!T6>( zpgg#!fMF`SN~=f!Vv~A|Lq`~v>iuF`kLwl7ht7qO7vbX8nsOwLL;)|d#(%Vp-aF<> zHz@Oia_syV77{Ghz#&y5@hVW0D^dfJwFrwg9FA;ffIr}2GriQS;m~5Yl3OWwLld|3 zDL6Zue?DcmXQgppFU%PHAwhuC^s+aw&4$%oZBuSwgBG1ryI+7t0S`6);NJ1HJHGhg z2hXVPY6y-Gh=*$aIE2WE;ZHVPyshQXU4NUF2}@eFY(*T`ZU9wI0}j_^<^W?KUlhLGxilmy>N;QIY$uf#rIJbyY1y1sEcb_X^n)__3>whvl;R zpa*qF4AglTZ)(_IzmZLA-Sr+H9HP5~q0{!}TPWd$;`kn+0AK`R81A|8vXZyBSSE+l zx>;?$ufgc=cwsrD0{-!206hkOaUT}nNl?ldW(n78B`!G)1q?S9P{Q=A2q7x#g4TWn zcZ2%PjEw--gLo^w(SQS*{=J3$O;&#Oy7h3{RT%&RZZ|{EZJ{vUH(Lsw-}OOu=I>cz zp#DUX)_i+@`Gs?+f!(=w7BMI8I}{rQt4$sf+AxF?a!?`kCAs5(Q3Dm;^vcd>m-QQP z`YMZ&at0c;oa<;i`hqOE8$fmUF;Zq%Lv;5G<+_%DPr+|-QOHuke<0v?0fVVP`7i{3 z+Ef!wGldEO@WJK}j!U>U-0$*sXrxV3ANZeyF+^K&xV94n@ek@;37Y{QF{RN;kh7Ek zZW``}!S1dk*wzK$R^|kk-WU`xGXsx&vLtt7U{{OL+1c}bn%@ZJHLvoB44YT;W@YU? zBZB;(P%|DAdIX^^#H-#VMN^M7>fA`z@ckSxVz|9keMyqgiVQk%UV<}$u}p@+gXnB= zKU0eo8HCdzk4VYCf{3Q}j*2cEf*L zg>DsA0({mBxN6wnZE;nj?U}4UjTlyx`zPA_-?}o<$eLn`*dCW=qfsbJuPR>C2vEdr zl>-#MxK(@WcNenx;z)Xn&ut)^SP3W^3K0l@Ftkhl;&H(cBmh2=#09a78aR4}w`_R1 z77n%LgwZeT;Ui!A;KcAjV}Qve=%Ng@M(r4gA<(KX4>c>;ZQex>bOSz4I+AASI@5d_ z=P#sI#*xWrljJ$4zjbMynZ?5hE6 zFDNQ1->7;8FA7-}L50$!e4^AG6|vuucy%J%hDv&8Q7+SUIjI6*AsiB{rlYL!VrJeD zjdD`K)v=&D4d5scx`Ee|au)1*4&hc4f!;Z~nI0Svz6-%q!S_O{u#^v#;>GekH_W&{ zFCHVTNW)Ms7&8p4%NSi3%gdG0nZn1>Y4KQE&CS*D6AM_M2ODQ+Cz8Cr{JP?Z#Dy7E z>LKz3`n8}&mJutwvEhp*_!WX`#%LEIT7}fnFE}!6@;vU!Lm4oU6zyxHf?YlP%s^H9 zOQH}_yj?{?Q-zpJ?Rwy`4C!B-P{tXynzw2LzCk`{z1olITB^K7#1s5yYmf(wH&NgC;eZJd+4By?H z9%m!lEZF}Qc{4jd(G1)Fd-2%6D|c$Qb4Rm(gb7rU7Z4RGBvj7wbCftjcelC{e^4|C z5?r_$#9-~Xx zASo~G#-j(8;m$;R@gS+rN2p?o8OH`gK;3zKRX5ps}^yp8MZA;-91sGoZS{9FW=vY zT_Zy2qI5CSyg&rzYNM@ZhSWlEx8EJgR~J=C24S^B9-0Jf5&zNMFBn;*$TTZcUm}Tz z#mC|YuII3WKB+y+Mlx&qM1V-?TfoDbTZ~gD*%W6L#Vp{Fn_Ev34*RSxy~Z@0tTpFJ z?osr0ye!9+!pKLQ?Rk8oTv^BngLtFmw7~p*a#$FIbWDvqb2++dei%cKY#=dXe>8^P zJ{IO%`*WsF+*J7J9nA zzP`m=vA*;Z*`1=qo>HnDV+c1Jr;vd%_hwgRU=Xm52m1ny$Fsq-O1;#t!e5D`}$#fjEG~w;eND^I9h)Gz+rAnpkyD*H7 zVIY|3quIeS859nq1ZjS5n*6e{)y-ACg^u1EC;JHBmaPR#!ip=l&YWA9FV#vM;QMx$BoA& z2W<>`qIk@+*Nr?mn2Of|!@BOZ0zNf@j7~nh2D&!?sf~3{^pw@aQ__=N^0euy_deqo z%ddan2@D_bZ%SCCBszG>zL3Rqj`7$LI*~GHDw9>p&ZK(O=iGs9*|39yt3e)Hj?B zH9Dv9Ij`Tr*$y<}36Kz(r?`jPL}$ScAzBEkWF3%(s8?B55hNvr@oODm^oWJ8B5l_W znlCAMIqA~ir5?HVy|FGTr$WRrOzX>vxG}aQNjeHNl?YJOs&+O}*)~V^xpbQHCvHV> z{kOW8gJ`{b@uAT`Ki(GsE8uYG)kz9TN%BAmmBc8i6ugY=X?${`K_zD7PM3akiZG|v{sy>N~)e0t4W zma??K=?guhJkH+vN1*|2f(Q#J#=Bn0R~o|D5;LRS@;bSxID-tDr^p(EIt3Z2^`CKl zSVtN%BAAiK`HH5phbf#g`*afHtI4a4UO_!}j8|2We zrnK|RKtUtHLVmrXJ74nP5V(95I97&w20(#jS$>p!RczirTw;}M#vU3DdbcsD&IRe) zqoNDb<`Q&H)L^0$)1zg2v!-l8355s6I5l&N0pXpB5p6z}qRsym=%qr!(WEmQ;;+WX zXES;!ZH8uQxu4@~hPX1xry4#CemJ}=SLKv|XS>p7FYYzc6%vIDnhf=f@Zlv(F(EB% zL`eL>Znp@vq8K?XO{7)|R4(D`qAIHxCDZl*t(pA*mMF0nx@}L0%`O&}j&_(Z{NGyi&<(v;Q0f~8oXx-gf+^LWwXt+KYlzPIjFIaWnVge-RPffHeZIB3GkKA z<_WZpSjAd9`V-ufHfEl$N()cSjb&^dQAO36x4jY|nA+q}vnXkOXQU2?1B%MjT|c(` zT}o3(o5*bWC*7%|GK%RR~v>nZ+Wb4tSv2`|L_H1&$Ispkjf#MVaK{N?522=^Hv z^^uDO1iI2u%h;&N!aS|mzCr)zl z^Z&q#dF$(%Qp$s=>D{~r1{P&M*A*^Wn>zjvsO&8-hrs&5!KSCw@JrAhzE!&GzH7Gk z>hWvyzd%h(hs?G^|8%*1g*6F|2Ir>-dqqR*)l#1|5-z+C#JTSjc^6DjsTT%3AGday$OY{3rA21#smc=-m8(tm>9wB7$vWM zT2aSmnA`4VFI-9iVvzz=Ly%?)lHGPY%Zm^jCuh}(OAR$hOImo^7F-5TK!_XD*{5(O z4+};UPx0ssrr?lj)DUuLHSoD^-axWWK&C;tB#TNDhHSy3*NM#_9d?=P{`~ltou8ll z(z6&%xs7XI2s6GM?2o6HH|2~_pgGmt4#n1h%$uVerUI?uHJpr&nc)vJGsPcfcgLYh zlfBk4cvbv0vDkYFS?4+s#X0NC(8?J#=V`HNUUcc2b4AC@*TZ6u^FuG@H3TU@NI$1^ zT~wTOoRvPvELY_$(MUi=nmIItg^T%mY3B2;Y6j`mQnF}H3T%Z%N{en^B0^Mz!<4yu z2{ENl!gNY!st|bs%Bd7uOs5>tIlN@?obh@nV|3#|fz$XbN_CRH*4$1=WGx=KB9m?2 zRu2KmYo$ERFRDv`USK45p%)D1&O{jVNz=r1TOkxqY6KA=0yWuIAZQ`aAh^fw*EP4K zBG>O}ZgDmi-PnE1Ey~8Cn`$+;+|_s*SWRUFZYMr;a+1lI?*(=+)pV>+lZxO4K;j~BAazuh4&^Vb}xTWJww0O3FEBxDz6l z;66R0oEAY_+#B;UDxOUNLpW^Q)clt9qT2&~wobgTVdG^<@isfJIq*4#F^19|>Cmai zk8)~hD}SYPYa36V+|z^ogTrqRoRhxFmlu2J3%n27#4$5Nw;ogj$%5GPy~Uje;fT^I z%;q=tK|^9Bcs1xAPj2DBWPM3cIPh0V?Wf)sOnYd)k_?4KA$m{&RCaUNjMYX-0HuMY za;pqsvcpgk1lHNb2K+B`hvDT-5VxB`Afn@MqkAa8i*++FyibFpCpSsa z5ry1~bl=)cU6~$4%lku4j>L1c!BO}yn(@rlnTLP8?}BS$zk@%C^abk4&fv#+9XKew ziV}Xf(}hW~^}Uioyp9)K_oZg4Z3n#AZ9RPZR+!o7xYpVzQgxoSdX_JtE^rk$e zY)&$?ip!!vk}D2{^Ch6<8uq=FrVUgF=~H#c;1lM2Ap>4()|b@#85Hq0^*#w%`U%r~ zTZR^|htN?*$@}nm2sUQcZ}GvfiQW+wAxKnVsInw-v~A?GY3n&;U!#HInJj~$h6u;m zQda<#lr7y!-l>5RkXGgRsv95qbVx3oFU!z{=lmm{CI~?P9nMm4n6^1Uk^i6G-Bf0n zKkA0-L3%O@u@LPt%GWA;Hp+GsOA0_55XFmR$wFo>*|A4PKzN+3E+pHC3JI4HTaDA9 zPr$xmy0c!m0V3eHM+|22%*gvnNi5OA-yY?CMI;fQ2hvCqh4V)dznUt5mbI;?CPTMp z3qfXhCrM`Lq3UfCLUO2kL)q5d;^$w2udq1gZ6YF|LteN9WUL!5=8N^T%sMh;g%2qtb2tFCe~joH;eIjH#Mlf zxe>d;Ok$QalP?9i#9S_w@fh0HS%=`0N}6K!MP*DeyU8#MnxAEw6-c7FzSA=erWuCM zxpaL&FP@ke?Dzi&&AFBS5UCmV12RN-ilZQ&5fws)=@*V9&_F(f3!0;W8=}AvEy6D2? zLEUZhj~+sZ0ood7DX(Xk-`-p>-0gZc2^=xGwW1KVpWuaHGC)y+Vt?%wd&Rkt9!P`9 zcNarBxkAsSRHKwG*A)nXIOWwc)EgItpxhO3{t{Z#J+pW9#d<|OK8W^>Y+U~95$Q9! z;S-fmCv+kH0lw>d%!@C7Vu3DkO<@DjkF=3dS40mP5UfmdxC9Ai%3nZ33Cx{kAj5v> zlkG4(ZsyAdHp#4LGd7RKvYv-1*nsIY23fMEB!!dt8X?h*yiUo@H99$Gt{)__MaV9^ zf_c3*%coS+v4#TGxO6iEs@P9b1W$qjI*6yBZfru`G;#38q{+jiPJjYY9bqJB!EUUs zXeJ+5`pd(E=TA2#dG)t4)~L7jl9%M<<4bfQceXjC%N?REMLWOa^wmh*pbj7KM954` zi`I*YX|JmXF)+_=z3Re}h3waTHqwD5eh0VNy${{FJMY@Z*)b7&a+~87 z23CT|;*2WB*Kqsi)8B@>&~`oc`_g#Vb1d9B?_$e%)BBuc^uqIsWMqrH(p2u*!Nz3W z+nrNo!}~Bb(%p`(HMP5tvEvhS3y;1yu-6tIhtQw()#KM;5GxrC7{=V?9)?e>i~qH+ zOUygT0Lj_YW92+RzMrf_fSj{05|ATAQn2Cn10OU9pzYdZbXYx+JHD4}3MW_}(0M`^I|%Tck6G>DXLTPYA502myz3q-?+fDhcr$YWvPHESK) z(02hN$N|ik;e0^=L}dda<}Q~O_s@J@lIvXGYmOZtR83{IsnLjPHeo$pft&e-!!cb+ zmK_ln$w|bIX&n9d?t0{*&zB-|_ls)qW?lzXtSb8lE1i=1@%h z19p1w<@1C6^TU(lE%;$5jO_5&d*AL2pYI(%8}2{fJ3~=!vys(1Zrd5*?{={CT?4NZ z=3|S`iznoPNbazYUpq$l?B)J`1Xs#i{!tHOa4b5ld|nU)67>rabWW38y@%|owY5C@ zfrBuBUhf`SansX~kl3bP`@&Y5X#=zI>`aNv@Rk;kK&X8khg)8n(xmc zDMk{C3I7EnytyY3$Oj0}$kkA`Fl-8bXH^tv;i~r~G{@Pl;GOuku5Ki%^PTybTL1j> z&x^TPKKRRkUG;h(^`h3v)O6wseQRqi&2G%ErK(jctylGQaEjk08dBd%IuMuLO$J#v z^1O*793zd<=-t4TPK#Rm;#L~#tAT%t*5Y?%js6PKHV(TuCluXsGJ*f2cSkuTMPk)_ z_JS7M(1pp|b&Cqd#qHmI#2L$Jf}|#Zkc5AupKT5Gx2_tr;jhiH9`Ig~Joto#1gK`S zf9&e1aydC;JVwfzf=6L&PZH@c382l+|B|CcRn}@6T18}*TB(|WK1h9_x3}{(*1X4b z02_jq^$Eqe!V^<2ZRUD6WM3Eg&B6CT7XL2Ce!f)a{qv5eH0G7asg3z-PjL*?l+>ma z`Om#56*Bx??lLb-m{}&L0k$3fp>@FbQvRCj_3nCDPeR~RfjN|~?;%?oDK6h&QW=ph9{WUn9I1T^&$tMw;@lw~ zAuxwrA&s;I90?r>PiPy6(Dsql1}~4lIR<=+O?m?vR$4XS9h~1xd!V)=wBE9o!1S^I z&EB(vx?%X92y`*N=Y>UH@I5#N57-X*YEb}u=` z+T08F#pr3sI=Ol34QVz9Q>Umo=}(xZ)1PQb(nS7lWY1e5{iSSZ#(iwa8y0`HyzN>M} zm4Y=qJos)qH=Zn&kA&~Gsl9sqI*Rr9lL*1>Rx|{nQGI2vo8hcMmfO;C?T_itoIkb+ zwJ}WBYN)^;Wo7{;|@nP!=WYFx%+W*zF|jQW3RbZ z!{f(gxU%*KD=ox2QlBs9WLjxtrv}cTXqf(e=x8d z77s(8N}&J>{RKw&MWBlI_>;m5vDdq@fXqB_n1T`z*yg z5CYIH%cLY2ta`-3!4zM|S~QHqqnqfDiTpLqkC0NCP^jFtcv~B&Hb+)7HkqHBuDZPu zo?Fl1hFY;DRdcGAb=m=ZwPiV}9?dT=?IA}a)U1XFp0v)8*<-%qWXX_jVBNq2-6Fn4 z+)60RvGnH10r}MVmiz_YHSnW+4|CADl)pd!%Dm4nPEY>%eIl_kK2bvNK{K;64hQ6QI@;%Ybhi^@RO3pIDYYX7V>TB@j6lmyqXcjXhrnr> zrL5P^3HY1d)Q#H922!ha2vksv5%D@?H*-f+CQIOvn-x(2k9-*}&Km@N0lQv-?+-St-!&NRo8S zv_-jCy-mx7aq8si>cv$I=p7QQ>+^YsZn(xBI&w?5T{o`$9Xd%qh}r&PDMt?=*%mPg z;lW7YP0NULXD;Ouoj<&4HKMd0D%lOA)@-%DejU5U(L`iEsil`MW9c+fN3-t=Wqb3# z?S=#|`{;n=l>08wbw+#S>bk#Wo*0)ZCpl{QzU$rVvud?o1Xw|S`;e?4n0ZTVJ+&8* zR%=)eW*eT;5Lc%qPe|@wQp4`sZZB$WtMi&@HF>kT8=BOr>GUxQG}X}&sW6obwm+`Lx^XZ`DLj(Gc%^o-Ir- ztBw#GEY`WqS=nYS{_^qpDF{M`$LGV7FTOt5KOY`G&2EXkEOt{8S1={Dg^2$gE8<^n z@&jEyo=MP$g>k9reYVO-mkUR^K4dAopqr>ADFVBbj zC&%9&oSv)HFPmNzzb!B1Z$p$~2m_d*sGIDmq4rL{Q3_`-p-WGm;f;a#bP8fs)W+^4 zFUZ-!safMY_xKKc3dJU$cp!47>)qH7D9#;KPZGrrpWY=>#*5)rN+4ZPp)6vX>QM`b zaXO0|IB1>S8X5O`z;R*c^>uYrPp5Sg=S<5aLZO#Was{OlpVz`av|t03{TV>?Dh)wOcVe^9S01qgv7pJGrhHSyNsO)GeHl8Ya+vtC~Q_q zF<*b>PbZW1#JSzK$vhrr!hQ7ut)le?<}LI@i1}5I(+*GUmg8tkC6eryIMH<4L!Bxj zRFLdriZ{ca+bnPLvcgRk-Ig`0C5+k`LeR?1dpqhVhT7l&ktMU5o9@5%_J~H}%d>5) zspk1IhS)9lb+>KHgC)q)LeMM0;{;5VdP;UX<8`x|U&F)I2@>An@`3HZ)I~U1WXRBm zzc$nm5>AoZbz>MVSD>0TYs1dDvnY7p?2>+6PS8gbZa!A9c2&>~PwVl@rDs;2>+k`C z(~Q7T&|%)4fM*CefKD4#IldD6-SvF3#w=OL3L@!|p%ZtB&Tim&UMjEx)YM*7rg%MH z_F@tb3%668EZ9Gvf**)*3g0>?p31CdHe0y2-|JcZ6F|gJCKP~LoYc*#o|)-G1LTt@ z?pov};`yn;r;8@c8OD1=j6i4GE`OCnNwt?M5cMq?= z@|aYE?fDb|F=LnDYSP`6~%-6}Qrf0W@i}rUHv%{Fb=%rx`d_MB}mbJ+8K@eT2b`3m-g22MaM;Yik zj=Z7tc0@N;rl}HE2yj!SL8a{sule5TX^?j@eG$1FUGqENaCcldDo#qGDM5w_?~2-N z|ACdLz}>iYTPL38zwiEnA0B4#+#k{LR2G^7$yfkikHAK+G8d`!?c9MA;`Kg`2mfm<+Q?ro!x+o^R3bczR3x z1f_&)GLKh=V-7bn%rv{L_zNDZ!JH!x6L8mpCt%PA;`HfoPR6dSnvH zdF*?{TPKcYhX20vQT_kuEas%zbO0I{tTDeVDie6s6Y}q!g~M1z`xigs_UvMA)B7J> zmn!@t9a`Ot=hFFGoMoNgJhaED!T5s7pYp_0F6ckUahF~Bu!C;6>a32mm2U_SLRY*G zs!2}_I;1ou^?}*LmV7WR?@8;9CE*S~)VxZkAUZDn;%!weGTw$hqnV-}Lv>YJXa*=( ztwvYfbR$BjTws!3ep$^I73MGX)UY;FSJB0gYrN;42tux`BL`!U&~l`vy8}%LK}nqh zip233&ulIp9Pr;;Q9f^T9;fT3eix<=+Gh0z;XpUm@J5ZLVV#rOw*+Bp%VQK;<(Pnk zI_Bh3di`(hmV;dkQVr&PSbV>p%w0=)8*-&xs!;>rrLSjn63$Pp5rbJ2ce~saoW7@EQNEgpH#;jL zDHGSZ>!A-elsDr6SJ0G>l$FhjFL^zCH-9SwA?F0NR@c`x_OqmKR`YpbWXoHgse6w# zWJAWq9zI$vLrsugCNXh%cJ%^xFPBx-D}or#pd+T^V})G?)Zp*vmVr!^%S7VDDNhQGUPQ@oj29$ zhflwp&pYO9+$wFu(XIJoAAVB*eo+$QM?}Ie5Ek2%z{p@3ADqJLay7m(<0{{ja5~8F zGh5Y=nY(^7D@}iuZ~r-;QzdfDSV;;Mj3tq)0k+K!{c>!+?Edg+CtfAD|H$Uj8L*8u zH6jja-MoG6BlTg!EDNp<{Xw>g4;A2Yn@%{609wO@0A5P~p$-Z8CSI?(<`POu{7Ov* zs=Ms6LvAkJ<8ltBq>;_sIS+!Oz!QtV&xUU2|4ciP?h(MgvH#%@Z zim{i|dTf2Y_06`InyJEn<7HKuc7RMw%k^x0WmPf(H_`OTo%u05B1_Z9)o!QLr(17- z141-kteV|dul?t9Qdh(I%&_>i4I+_fo?&C&#`P)@0-{S8;fm?iq%^d9R{x)BDBm;P zLMli@fgghhoUNFktzRtXz}OAnEa&TmjcFmTHE_I6@s-bUae^2$vFRbo5H(7XPXdp{ zxK-|mD&A3^*uO`hQv>z&=Od2x02V+8r_`Ffo>f(`b7@AnGxB#a)w*B8bO6K_JPBOe zGYZ569+}8HzUj~b8KNUfYfc7bxkS^X88LMdGZb)9$Qgr*1~;TxRAajPh%E(D_2R0! z2AiamM*=gc_#?Q8hfbbA6s%`x@tqhY@DAC|JA7aG|XOaV8W&mO^( zJt>zHBV)3&J>=QA6vhO{6((c@+u)Av{L|BfUFCR_Y{%YU2%q}EWkM8Y`O@IJow|c)?mNu?v1>qcolY(t^Lj#iQo>?#O z7PR~6l3I^!!I67Q{L2c={zwKj8uVfsZ^=m{O5xz-_{*)tcf-^|9xq_8%+~QHN@-}f zp*$+Z+Zr|_TCsq7?LfoJia=lJ8<#y6OZcz1xdDi0*<}h52n|>RQ}F4K$70?LCe^!{ zc`*kr77L>^7oZ{V+PDPm#-wV-%X)zjWTkmnD}39RNSVxUgW@|_7ci~hNrZqII`OMI zFp<-3Y9u>uEIkIq`lBz1Y4gvuAn`keN7b<;5E{F_Mivy@yexC$Kb+Fa1YZ^`LXK?J zI}BkJPv77ON1!rrrh(lM;i;UFxECx-Q*>#u`@Bu9Pjd8%_loQK%@yU6qyWQ3wVd9P zWIG8sSoS`KvTMTFEASmP@9;G-(DjVKL4X*6b`|y?^G1bR5zeBIe#-P}O-od8;-j9V zd_y9kd<6SUi^Vzc-w+}HtCORH&o1iOXFp>2q3^EeiXeAMq0F3-2v4MTHYDaGrsWy_ zB;aG={?yK!PBls0gGuX01-YsL_i}+a z4<0;tSnTgTfBwbZ{x`+hi-Y~cFAw(xy(I7`4FOC*-x631k(+21<`u3PLD&2~t|+P; z-Lp}X?bNRi?65Faos*JT%EdUbTIY>4b6Fuf9_kG3f-xA@lJehbw;^S!SdcPZ63P(E za7RoDa-j`%26sVb7Ca{9%Nx`=D^yUC$~o6K!B|T%7>VjiW1_>>xDxA zgeUl0&nFdndy7Y@btq_vvs$967j+Isrk{GK+yq*9YZc7nat(3}oUln%PRW;k;LmXi zGuSOoVJ4x8y>5nL^)p5MG^4rz@OeY8lqKa9Rg!Qaa8XpjrO<{krG#M%909A?mY4(p zkpf^w01;i!y1jv>8Rg{=-GX=7?_>bcaOOCngM~hZeRrB6*{u0raNOufQ7f)F`lrj? z^ySZIC;Q(F&(2Q|_Kx(jf^VleEko|J8@KFAr(IRN!UH4DO}g1>gqW6pd$d|Fm`|t| z@BXu2{MG-9k4gedBsa_N)~q4?1bX>o+z$9vR{-NNbaxzD(~Fua&7=g2K5<1u?;9{a z1YF~F1DSI=d+ftReoIRWO~0~m;5Nq2y*Zs!3igh;v;6${mz|%V{L=ID86{TRYlPi+ z?EUSU0p{@kbbu9<>lc*=fN)gaUQ~iU3Na)N1(BHZUzM}u3h4eGnYn^#};JPiYUouB+ASZ2=I3Hw?9EGBrMS?=gIYnwZaL% z6WgL$Cr9v}T17VMu1C<+h{&qwMy_*B_JgDGAF~z-$B9$oiuPD41 zA#lb(hX`mS;78@4jRwV@ZYmU?G(;i$MZn-hH3jkfU_(L-;7IR&KShS_RJX?cZ|Raw zj+F9($6kaRA8^`s=W1wK%!1bjcg-wq(izBm3}|D!|B)nsX+Or!8NH(JOqks_gtjw? ziVD%a23NjG1sqc1)%0JaUN0Op6{T&(IE*yzMxTo!rR;g9umA#vhaWO!#Ut<^HhOhH zatL=KmEo_PfQ_96P#jyBs0Vk4;O_434#C~s-GX}{!6jI54IbPzxI4jZaCaT_k*&LL z@2!2i_tjKSbx%!w{nu$ZJ^wlLf3H7`0C*fPp>f)rq@NK2Xl|62$$!Tp28~|64V&Vp z*|#T%vtR__J&XPhZ{BO1fEG&nz|@coRivDNTWmM$gS?x46J?SPmv;&MEX^~Y0@|v* zw7>QSeVzt*pe@KceXZ~eTVE;8rTAVEW}`JWN4Ux`3RSiCH7^XG2)B zjY3kpA>Zuuv7Q6HEPg!=_*}F=zK%z+&xkLc5nSkj(L|3<)kySns|N&^$^5}p%9S}A zCJlH-j2#!=JCf9Xhi+Cf%3$9gut%ha70r=!M5_5jY21EQpQ0NMhMuXps+bXLIMz#- z{A$jd6^u50ZFTgxuo+D3eeP3-Z&yL(E^<-L07fkKFRz*T!C(mfbN$O6* zn2&ugwZDQ}cS#-xi_xkA4$a+InG(d{z?2y2JSYcV#+A8DP-BqS^A3hd{s`3cA`<6xs7OA?@TER#2})$U>+fW=4P zM(db#zMpsI?RO&9joyWQfUGgT8m47OlPIci)@uQh+N<4wGzj39)rCq^q5-bJIzby5 zwo}hSlna$}WlV%^1;95{(ZmKsp6pS)P5I?7E83uwP$8Ct{Nc@YHZ8(_S^Gf&mPo)~n zX#unMoKIlY?gvj1TI#h>HW~C{$x@-xAjt^mVuSN70t23b7RmE4B@O`|R-I4Q3^lcg z%y35ps4ATcz75w0!^b6XLVD|nGdl(ZZjCwmdoI#WbMhjsrF!})xAl1>-&cCvp67F% z$En>q`wx;6zkl@}e~kQb!-;f=4d>^tg=-wn0p8-bhb^pXIp}-dx5sj!#`hhg?L0np zgzvk^-bhZ)56rR)lftqr!I&m&rP&kk$uDtKUU(3ls4wmSFcAdy2*`nWKN;Lq>&0xX z;5o2Cn!Amit@b_JZ0y}kf+wxz0GXX$Y?+(G6tI%)0l1g zj3WHD{K*kh7F3$6YfO!39cE{oyMnzVUu5czvt{_0u9Rn7y>azGmtm|6?9}m;qREt^ zWmJw#y|l`msLLoqXn+lBL{{gAdhOY#3)zj^)vxwrWt~CC%bz#J0_Gk`O)DK9)wsL4 z`buNK^aTUHJPBkQ-mBR%%*j9HUp}?X)gGF7A5(HRvZVJp_4D$%*PY8?0Z?=VNrW<3 zC9j1D@B9b!3BV`@aIKlaXGcCj>%)%IiCPoqw|%)v<|s#J#O+c^#LFE8Q_bZStzC;s z!}=vTB*=I1wlgQgE}()BgfS%jiOtG;VNrtp)ck#zSj;ajo}}bj5t5Vw@-Cp=^&YxD z$g|z{BoYSeEkBpTbE_>CeCBjjtkb+Cs~+|io!dtluhqty>U9shdGNx*s@kUf34wi# zVrBwbQVNPP$;KFNeIfZHxh(`ZTRg3=_HN1pR;K_5qCI#pQzff?l!`Yi-)SvDIS;ta zb|g~+P|A%TD8ioiTu5FkofOR|W+Y2bzOWGPzu>I=xg;tT2Mie++7jGYf*BU9+QEc} zBu?8q*kpde8mnW8>lgtjhHGP{8mCoB>DwhGlY`_ybQ^^hu6X5(@@lMh1$=B4u5tAF z?ddjHBKd$0AIuZ2Wn1skhqRJL6ijO?j!Hhi^$LWK}qJ8$i~+T zhOH5QDJO7)F}sgXYf+sivq&vHKqg`=cW#;$vW`oAyeQ0DZ>_AuUXS8*Jt8Z;tjc|M zuWYr>raY(m)Q1c~|0nQgQ-h^m?Hw7X%+vtb)zk}imwL56(@nk}_jiHL?|SucecfD( zAs4@xM~aBd%)?2)(o6{F8X>xw?U=X;QhrG?PewHaj{e7(Xvik$o|(sgs=mE5H3_Q4 zlN^OuDy%d~W=U0$ohb+m^y*}UGTF)NNBUUyK0$)rBeVDkOJRu4ydpYajV1IA<)1}hOCSZRG|&J5Ri3{oM4VlnJ~^2= z+5bxc*arUh0OelJ3-2fcpn-RsbzGTde*cpa(7i8^F+cie8YbHca_S@+9 z-(SeqaT1$aNZh+LQidf|3?Ch=66O4wm+^yf9-s_f`ovPo+@XI9#MtEda&{hO+LjeAkkhy{qk5zk2uW53j z&8mxe^YC+0^b$YlF*kakiU3~tPC8!kE-`(6w2T0;fH`m%VKdNJ_2QH__S)yc5@#JT z5M?kC%3Cx$OZaX7+9===1QOmI^CGj#J}5QXDZAXU6~JX)@A<($0W=?&o5oU01nxSs z24HBjD5Z<^!XYai^g&5|23dRIhZbn{;>NyS!XUn8LF#-71HT7CCib5=c=1C8^c?md z8`r}Qeuvd-nXu+d>+162gLAT)T`_L-vg?(ogkXUBg|zYY+W3wtb!ql4^LPxq>kZZ! zN^5%q4LAw5zki#7zt7f;w{E~J_Fz!SB-A{AVD5zhq7(Uw&9fBnvjm(GKUx1YY^H>b@>Nc>63EZi6=VHOhvum{HjW zLA$v#unM%{Y_jo@8J?>bt52Ugb-ak+iF%<0k4Z$`XGdxAdd(VhZ2!Ssf|v2k5(8*0 z6jlKQZcItMDeeJ-$_|v1ZAEFC5Ln>*aakpE!CFums&yGu#kAhI;kqv@pGnA$H?4i# z*h6!gOi?}EpsmQ+A&ZO|e*1h#Biglrk1hWsC9a$-R%*{7X0Npke5~bc=%(e-F_!Tp z#|*$AkqQEpxApJdG|84sC}WgK&^Sttr$JhE(NyhDJ4|5npx~!XENiP}`eNyCG+KJ_ zkpFDum2`nuJjigl(6hUOP=x*!_jyIp@{Y%!*HO-7ys$bIsyAbPm0({4M*P$44}qU# zg^1liIUc*~wy615hmJ51T>K26g{U{}4E&>Q?y>O4z69 zxP1uP$bm5s`2s!TdWKoYaXDv#)iXV+Zh+6s;CkH$NNz5}*F+S$^Z;i2p21 z>ou9|VW=!tv^<6q;Ri^S=^}`|Yy^=JR00^UT&T$)q3aWaavKpJI0Y@ck3zS@E+SOk z0|MVRGg?fLJd_e4wsW7kg@f^`54?AvsJ!}vTU;9^INDN7y=S&_vLU(M#wxLD&uXyK z;2amohE7hQ+#^?^n2suyzY8HDAGbH#T~ZC(grt9?@T}l5Z!hiYi53)rN~`g?#fC|^ zFWOu0*w;WHs&5ha?rgcwhBRJRZN9WzVP)A(6T~O0_<_OnQm|JBx)8!9VIs;4U;N8C z63?;n^~9o39$H9xx*%#Osj?Tga=zK1iF+s>p}!2A#eJ;(6~$%;KqydiADkb z`Yg(rV-l8EB9`~43m-u0vC(T<6Gut@eDf@PV;d~mjR{`lJpLsOXG}k?%2ak(_cxYk zX>z0Nkb!%N_X+=I)f~)+aEu;`Tkm^c9~;ugSXN?^T}S&T|h$MG6~7YykGQ8 zjDD}9)!8m>|1m$F%3{fr_7s)^GsLy#5VA=3=q!Fgh$unP?u$J%F1;6=Z%{jn*}Zca zf>Z(F0sIZ|bi(GIqc?HH^KS~WQgYpUWrytNtHp2AUD>V{?O#5?!;HNV!%NepM4}mE z5#&)KR^Rq~zE84`rd-sEtc%gBGub+Jl$j%|L-U84Tu8Cv@J))@>|^cwvF|`T*X__B zmi2+%;rc%d=+JfhVhi?q!F z8`Wrio^|GMr%5E7PG7(75!F~c?R$@2!hr%`a;7&zEcOVN1G?ju&g+Dn&5 z$ZEAdc2!?yhiF+&OK)5?Ajlj}@%8z4$x<=C>0u4(f@n}e0ZUes&RTPIlo zyXnjlRi#KS3AXcxB&UK|WLoJ)d_P<~5v^dS7M|b?&b3qYu+k$IXsg8Np&niNMSPBd zbZOTf9=-w&?4MoX&rxeOnpz|p43mD_5)Dl*6j5_b6>5!jMNfgtowD8K_;3hcz#fD% zq8$u>S6R*~XM`6I$<&6b*>WB0XiNe%Z#*Vsn=o$X`!Whui>D^tIt^@fVRg4eMP`+OeZw6%&fc%@AXeR}z|8#iZ2*!bKPDBc z#XQb(7lO`7j{Z(0&Lj!=g@?zly;<5;VCM$(<#zZW5iR)Z(3Zc#Lk07rV^W$}WlsWD zd27Cc=jsSI9b^RF6uby{u_AZlPcO)r+|^*HmaHq;2wClx_VlLwj71l6=R9LcmF3m8 zmftB2B6$)N1AWIxM0B&52ym%b-ibZs6$gdd2Z=V&77ozxG;)3tq;908IN7jR_Py{t zKzYHRsICWg`m)g|w5qM|^7;l^SjA~{ibY& zI-@BjQ4whJhb-cAp6P#XyS|^N5RPM6Ir-Art8q>_<2ea^o)I=zITZ{OLp;j5QnC2# zY$R$%H}=k8Q@_14KBxnK#jMf$_51>14x;&>n|NtIr}G*1Jn}aEo@>Ks7qkL8HNIf& z_DF7}W<*=H%{ z?G?KBJ`*}{j!@kqI7ws1^U;S&sO3?Rp3xhOFT*@eFb|{H1XX(!6MZ`gnj9bi?=qLR=vB^U@gxYKf3~K^d{}RhHAOeM&hr-efEkJ zU6dz!f0tvA*4?G;HDstF@1?0oDg%OHiz|_>(praSTHxdf5!#`wmdn%yvX zM0#~NgXu&F>vsRFeXJt5h~(99NN1!XB41O*(3SzNVJ!}u;8M9L6YQ_x+S2{-1qs(| zPMgD|LgtcUB1XIE1bR$FjUy7cd4BxjLVGYIp$hg2kOUhr49dtFnGequNSzR9F-+vg z11}He6F53dpgMVdNHMvc#1z+-9nN?P4X(3R+i%E2$~Q61cEJWJ{yG%3Vre#|O|I_g zHg_Kd==*1D#dk8qsfd;?iW{+BhC=hz2`QorHX}IDbnzGwF8T1PG5KT`k?IW514D;I zQgCQfp0%yzM_cK%Z8^Ogx#Q=jKV|mxb99uFzE6SonNitK?KP^F3iGbc|SZRHR=#<+vtZ*0nhXu+I;}F!O;IF>50iCbB)f`PJzo%p}D>0V5 zvU8t$6C(FI!S+L(`-1n-O5s%!;!z9vz(uB3i`pqqp ztJT#Zul69(K$S=C1--E1^-C-#$qY+DHXl}0Jscj4^im@*Ug=SMovG4ugt{%Vt7?Fc ztqI+-IoqOI{ck?Kt1}mX$052CWx$hZ2;S`i*HAA^DiZfG9vYu%yfM{8c@@Xw_r~S! zeWR6!Pz@CC%S(Pwd0DFYv_eC*nvJ?8p+R4t4XTWdUnX{uPC%$)E7|-re7`mvw*1*Z zN(j)PKrPr;1M6*EX{3twlna@MCV{!iA>vw1MBq8Mrm-AaW8Y%>Haha!madnAf-!9= zllMJAyDiNgs4+z|*?8f@5p#7=#{oGs8}5^_dnV8bRL^^o>;ms5z;WJ9tL3IYnt>)b zImDuz*0vp&d&r7i zR4&MsL|yWZyJ@SRwMwoE51W}3Mib1Yq}EC0Ci4yEztS^*{OaN@#10Km008WJiU#FB zRjO__R__Ane>137Xeq?4aG`b_YoQKH>hi1;`OqLfOpz(nh@=A)T;ai)8CAII&`T54 zxnA%6#lK-%@DG2eU`mmlSjgb+op!t1+Jh>0C!aIv6>1N>-NM$T%@*{%D|Q@EWbxdI zMUo2q7E>L~PX7cps!4=-&Tzc5?(C+E-T$=R&gZ-zsv@s@Cz*-?Hyn$<$8iG(t3#c{ z(&Xdxr8|NKz*((dFLf_Z#~;4Tx;P6h+=zWvl3_DT7z5n-v>H%x9A>(VuP7X7xX>Cl z(QMbu7oJapH(i@89G+!A;-T)|K#N1PM0E>d6O?L33r-R>?pV>8ZfG-TFug{9F)pQJ zi?Gh4S#S5hemXm&T%nqEsz+a>If~UAUH6{)*zp$a9AcAuf!|b%9A10)V^VS)%boyM zEabbc48U3kSF{0ZyUZk}9*`^K8WHh zTZJvEH=DPCz;#?I0HTmSYvSbDqF1NOju7A%I6#tZ!f*FZ4X5*$Vx30QNrsC1UmT(A zEV@%qLdtsh`4|OAvAK(kD>@Tq;h*f?-I1)htg2l(;RiCjp73}4I_1@o;RnmTAsh4I zk-VWs31@E0QlB~XprjNsgMKma%q2vZJE1SHzs>BJ?XZ5Zfzk);M}YDVX|K{3q%S=B z@5pmiXc!Iobr1s8{ce+BpwpN*?|cCTD-2*=_RL}qx9!LCa!q?7dUuzz1acM0pTUSn zGLmYO8mliMRQ$qd2?8TbUdpuZl z;Zm?whIT&eHc}i&}Ec-vdsFpsc4KYsqCQla%AEfvG}Yn9rD1Xb2=CWdbt++d-2 zx7kmLCEP6>1+Z6bl1aQ}YH?tG2;V!AD1G@+S6l^3wzhrwJqYzh4j*)v2}28t1hwR@ z*ghz7QV;G6H+&uMAumfT>6be%0b@Y$pcRd?cOz|FlY(K6^YKVujMygUqi33A8n2v= zNBgJ)z2@dtQr6<1tMW=xfw}CyHj3)@F zL~*5}*3jCjK4)S7a|Yx1`?;3j&}$>9Om01=8Q=*1o?=!4qEK4^VtaWMbl`h zqNiCTkDlo{0oV5M`rlY+!26diPe+Bz<_lJMCc>@4{3~iDY38vWZ>wlL2d-wOSaCLhZX#Y!Y z-pJ^$?-?192$E2(g*xq5J z)_`}j+~29N(0@_Q>}@O@Ke<`BzK8jr8MkdGBy-eEuuce`$n2*ni5||K14DRDWUr zp=$s4ru{?xr(6B+)U5xc{y$CY|4#q!O2EIi!Ph%lixZ_iONfiWKVoR0jay-(NQG LCzn+IpRNA`+u!Y8 literal 0 HcmV?d00001 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7c664966ed74e..dbb463f6005a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -998,8 +998,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) - return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7bed5216eabf3..ebdd665e349c5 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -29,7 +29,7 @@ from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main -from pyspark.serializers import read_int, write_int +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer def compute_real_exit_code(exit_code): @@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock): +def worker(sock, authenticated): """ Called by a worker process after the fork(). """ @@ -56,6 +56,18 @@ def worker(sock): # otherwise writes also cause a seek that makes us miss data on the read side. infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) + + if not authenticated: + client_secret = UTF8Deserializer().loads(infile) + if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret: + write_with_length("ok".encode("utf-8"), outfile) + outfile.flush() + else: + write_with_length("err".encode("utf-8"), outfile) + outfile.flush() + sock.close() + return 1 + exit_code = 0 try: worker_main(infile, outfile) @@ -153,8 +165,11 @@ def handle_sigterm(*args): write_int(os.getpid(), outfile) outfile.flush() outfile.close() + authenticated = False while True: - code = worker(sock) + code = worker(sock, authenticated) + if code == 0: + authenticated = True if not reuse or code: # wait for closing try: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3e704fe9bf6ec..0afbe9dc6aa3e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -21,16 +21,19 @@ import select import signal import shlex +import shutil import socket import platform +import tempfile +import time from subprocess import Popen, PIPE if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_gateway import java_import, JavaGateway, GatewayParameters from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int +from pyspark.serializers import read_int, write_with_length, UTF8Deserializer def launch_gateway(conf=None): @@ -41,6 +44,7 @@ def launch_gateway(conf=None): """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -59,40 +63,40 @@ def launch_gateway(conf=None): ]) command = command + shlex.split(submit_args) - # Start a socket that will be used by PythonGatewayServer to communicate its port to us - callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - callback_socket.bind(('127.0.0.1', 0)) - callback_socket.listen(1) - callback_host, callback_port = callback_socket.getsockname() - env = dict(os.environ) - env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host - env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) - - # Launch the Java gateway. - # We open a pipe to stdin so that the Java gateway can die when the pipe is broken - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) - else: - # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) - - gateway_port = None - # We use select() here in order to avoid blocking indefinitely if the subprocess dies - # before connecting - while gateway_port is None and proc.poll() is None: - timeout = 1 # (seconds) - readable, _, _ = select.select([callback_socket], [], [], timeout) - if callback_socket in readable: - gateway_connection = callback_socket.accept()[0] - # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile(mode="rb")) - gateway_connection.close() - callback_socket.close() - if gateway_port is None: - raise Exception("Java gateway process exited before sending the driver its port number") + # Create a temporary directory where the gateway server should write the connection + # information. + conn_info_dir = tempfile.mkdtemp() + try: + fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir) + os.close(fd) + os.unlink(conn_info_file) + + env = dict(os.environ) + env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdin=PIPE, env=env) + + # Wait for the file to appear, or for the process to exit, whichever happens first. + while not proc.poll() and not os.path.isfile(conn_info_file): + time.sleep(0.1) + + if not os.path.isfile(conn_info_file): + raise Exception("Java gateway process exited before sending its port number") + + with open(conn_info_file, "rb") as info: + gateway_port = read_int(info) + gateway_secret = UTF8Deserializer().loads(info) + finally: + shutil.rmtree(conn_info_dir) # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -111,7 +115,9 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, + auto_convert=True)) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") @@ -126,3 +132,16 @@ def killChild(): java_import(gateway.jvm, "scala.Tuple2") return gateway + + +def do_server_auth(conn, auth_secret): + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. + """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise Exception("Unexpected reply from iterator server.") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..d5a237a5b2855 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ else: from itertools import imap as map, ifilter as filter +from pyspark.java_gateway import do_server_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ + UTF8Deserializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -136,7 +138,8 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(port, serializer): +def _load_from_socket(sock_info, serializer): + port, auth_secret = sock_info sock = None # Support for both IPv4 and IPv6. # On most of IPv6-ready systems, IPv6 will take precedence. @@ -156,8 +159,12 @@ def _load_from_socket(port, serializer): # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) + + sockfile = sock.makefile("rwb", 65536) + do_server_auth(sockfile, auth_secret) + # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sock.makefile("rb", 65536)) + return serializer.load_stream(sockfile) def ignore_unicode_prefix(f): @@ -822,8 +829,8 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) - return list(_load_from_socket(port, self._jrdd_deserializer)) + sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(sock_info, self._jrdd_deserializer)) def reduce(self, f): """ @@ -2380,8 +2387,8 @@ def toLocalIterator(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(port, self._jrdd_deserializer) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(sock_info, self._jrdd_deserializer) def _prepare_for_python_RDD(sc, command): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 16f8e52dead7b..213dc158f9328 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -463,8 +463,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + sock_info = self._jdf.collectToPython() + return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(2.0) @@ -477,8 +477,8 @@ def toLocalIterator(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.toPythonIterator() - return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + sock_info = self._jdf.toPythonIterator() + return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) @@ -2087,8 +2087,8 @@ def _collectAsArrow(self): .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + sock_info = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(sock_info, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a1a4336b1e8de..8bb63fcc7ff9c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,6 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.java_gateway import do_server_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -301,9 +302,11 @@ def process(): if __name__ == '__main__': - # Read a local port to connect to from stdin - java_port = int(sys.stdin.readline()) + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) sock_file = sock.makefile("rwb", 65536) + do_server_auth(sock_file, auth_secret) main(sock_file, sock_file) diff --git a/python/setup.py b/python/setup.py index 794ceceae3008..d309e0564530a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.6'], + install_requires=['py4j==0.10.7'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bafb129032b49..134b3e5fef11a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1152,7 +1152,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a129be7c06b53..59b0f29e37d84 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -265,7 +265,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.6-src.zip", + s"$sparkHome/python/lib/py4j-0.10.7-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index bac154e10ae62..bf3da18c3706e 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cd4def71e6f3b..d518e07bfb62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3187,7 +3187,7 @@ class Dataset[T] private[sql]( EvaluatePython.javaToPython(rdd) } - private[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) @@ -3200,7 +3200,7 @@ class Dataset[T] private[sql]( /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ - private[sql] def collectAsArrowToPython(): Int = { + private[sql] def collectAsArrowToPython(): Array[Any] = { withAction("collectAsArrowToPython", queryExecution) { plan => val iter: Iterator[Array[Byte]] = toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) @@ -3208,7 +3208,7 @@ class Dataset[T] private[sql]( } } - private[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Array[Any] = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } From 628c7b517969c4a7ccb26ea67ab3dd61266073ca Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 17 Apr 2018 13:29:43 -0700 Subject: [PATCH 459/593] [SPARKR] Match pyspark features in SparkR communication protocol. --- R/pkg/R/client.R | 4 +- R/pkg/R/deserialize.R | 10 +++- R/pkg/R/sparkR.R | 39 +++++++++++-- R/pkg/inst/worker/daemon.R | 4 +- R/pkg/inst/worker/worker.R | 5 +- .../org/apache/spark/api/r/RAuthHelper.scala | 38 +++++++++++++ .../org/apache/spark/api/r/RBackend.scala | 43 +++++++++++++-- .../spark/api/r/RBackendAuthHandler.scala | 55 +++++++++++++++++++ .../org/apache/spark/api/r/RRunner.scala | 35 ++++++++---- .../org/apache/spark/deploy/RRunner.scala | 6 +- 10 files changed, 210 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala create mode 100644 core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 9d82814211bc5..7244cc9f9e38e 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout) { +connectBackend <- function(hostname, port, timeout, authSecret) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) { con <- socketConnection(host = hostname, port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = timeout) - + doServerAuth(con, authSecret) assign(".sparkRCon", con, envir = .sparkREnv) con } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index a90f7d381026b..cb03f1667629f 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) { stop(paste("Unsupported type for deserialization", type))) } -readString <- function(con) { - stringLen <- readInt(con) - raw <- readBin(con, raw(), stringLen, endian = "big") +readStringData <- function(con, len) { + raw <- readBin(con, raw(), len, endian = "big") string <- rawToChar(raw) Encoding(string) <- "UTF-8" string } +readString <- function(con) { + stringLen <- readInt(con) + readStringData(con, stringLen) +} + readInt <- function(con) { readBin(con, integer(), n = 1, endian = "big") } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index a480ac606f10d..38ee79477996f 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -158,6 +158,10 @@ sparkR.sparkContext <- function( " please use the --packages commandline instead", sep = ",")) } backendPort <- existingPort + authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET") + if (nchar(authSecret) == 0) { + stop("Auth secret not provided in environment.") + } } else { path <- tempfile(pattern = "backend_port") submitOps <- getClientModeSparkSubmitOpts( @@ -186,16 +190,27 @@ sparkR.sparkContext <- function( monitorPort <- readInt(f) rLibPath <- readString(f) connectionTimeout <- readInt(f) + + # Don't use readString() so that we can provide a useful + # error message if the R and Java versions are mismatched. + authSecretLen = readInt(f) + if (length(authSecretLen) == 0 || authSecretLen == 0) { + stop("Unexpected EOF in JVM connection data. Mismatched versions?") + } + authSecret <- readStringData(f, authSecretLen) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || length(monitorPort) == 0 || monitorPort == 0 || - length(rLibPath) != 1) { + length(rLibPath) != 1 || length(authSecret) == 0) { stop("JVM failed to launch") } - assign(".monitorConn", - socketConnection(port = monitorPort, timeout = connectionTimeout), - envir = .sparkREnv) + + monitorConn <- socketConnection(port = monitorPort, blocking = TRUE, + timeout = connectionTimeout, open = "wb") + doServerAuth(monitorConn, authSecret) + + assign(".monitorConn", monitorConn, envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -205,7 +220,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort, timeout = connectionTimeout) + connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret) }, error = function(err) { stop("Failed to connect JVM\n") @@ -687,3 +702,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { NULL } } + +# Utility function for sending auth data over a socket and checking the server's reply. +doServerAuth <- function(con, authSecret) { + if (nchar(authSecret) == 0) { + stop("Auth secret not provided.") + } + writeString(con, authSecret) + flush(con) + reply <- readString(con) + if (reply != "ok") { + close(con) + stop("Unexpected reply from server.") + } +} diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 2e31dc5f728cd..fb9db63b07cd0 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) + port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout) + +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # Waits indefinitely for a socket connecion by default. selectTimeout <- NULL diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 00789d815bba8..ba458d2b9ddfb 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) + outputCon <- socketConnection( port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala new file mode 100644 index 0000000000000..ac6826a9ec774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.SparkConf +import org.apache.spark.security.SocketAuthHelper + +private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) { + + override protected def readUtf8(s: Socket): String = { + SerDe.readString(new DataInputStream(s.getInputStream())) + } + + override protected def writeUtf8(str: String, s: Socket): Unit = { + val out = s.getOutputStream() + SerDe.writeString(new DataOutputStream(out), str) + out.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 2d1152a036449..3b2e809408e0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -45,7 +47,7 @@ private[spark] class RBackend { /** Tracks JVM objects returned to R for this RBackend instance. */ private[r] val jvmObjectTracker = new JVMObjectTracker - def init(): Int = { + def init(): (Int, RAuthHelper) = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) @@ -53,6 +55,7 @@ private[spark] class RBackend { conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) + val authHelper = new RAuthHelper(conf) bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) @@ -71,13 +74,16 @@ private[spark] class RBackend { new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) + .addLast(new RBackendAuthHandler(authHelper.secret)) .addLast("handler", handler) } }) channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + + val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + (port, authHelper) } def run(): Unit = { @@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging { val sparkRBackend = new RBackend() try { // bind to random port - val boundPort = sparkRBackend.init() + val (boundPort, authHelper) = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. To make it configurable we will pass the @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.writeInt(backendConnectionTimeout) + SerDe.writeString(dos, authHelper.secret) dos.close() f.renameTo(new File(path)) @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging { val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) + + // Wait for the R process to connect back, ignoring any failed auth attempts. Allow + // a max number of connection attempts to avoid looping forever. try { - val inSocket = serverSocket.accept() + var remainingAttempts = 10 + var inSocket: Socket = null + while (inSocket == null) { + inSocket = serverSocket.accept() + try { + authHelper.authClient(inSocket) + } catch { + case e: Exception => + remainingAttempts -= 1 + if (remainingAttempts == 0) { + val msg = "Too many failed authentication attempts." + logError(msg) + throw new IllegalStateException(msg) + } + logInfo("Client connection failed authentication.") + inSocket = null + } + } + serverSocket.close() + // wait for the end of socket, closed if R process die inSocket.getInputStream().read(buf) } finally { + serverSocket.close() sparkRBackend.close() System.exit(0) } @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging { } System.exit(0) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala new file mode 100644 index 0000000000000..4162e4a6c7476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Authentication handler for connections from the R process. + */ +private class RBackendAuthHandler(secret: String) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + // The R code adds a null terminator to serialized strings, so ignore it here. + val clientSecret = new String(msg, 0, msg.length - 1, UTF_8) + try { + require(secret == clientSecret, "Auth secret mismatch.") + ctx.pipeline().remove(this) + writeReply("ok", ctx.channel()) + } catch { + case e: Exception => + logInfo("Authentication failure.", e) + writeReply("err", ctx.channel()) + ctx.close() + } + } + + private def writeReply(reply: String, chan: Channel): Unit = { + val out = new ByteArrayOutputStream() + SerDe.writeString(new DataOutputStream(out), reply) + chan.writeAndFlush(out.toByteArray()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 88118392003e8..e7fdc3963945a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -74,14 +74,19 @@ private[spark] class RRunner[U]( // the socket used to send out the input of task serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() + dataStream = try { + val inSocket = serverSocket.accept() + RRunner.authHelper.authClient(inSocket) + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + RRunner.authHelper.authClient(outSocket) + val inputStream = new BufferedInputStream(outSocket.getInputStream) + new DataInputStream(inputStream) + } finally { + serverSocket.close() + } try { return new Iterator[U] { @@ -315,6 +320,11 @@ private[r] object RRunner { private[this] var errThread: BufferedStreamThread = _ private[this] var daemonChannel: DataOutputStream = _ + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new RAuthHelper(conf) + } + /** * Start a thread to print the process's stderr to ours */ @@ -349,6 +359,7 @@ private[r] object RRunner { pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") + pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) @@ -370,8 +381,12 @@ private[r] object RRunner { // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() + try { + authHelper.authClient(sock) + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + } finally { + serverSocket.close() + } } try { daemonChannel.writeInt(port) diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 6eb53a8252205..e86b362639e57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -68,10 +68,13 @@ object RRunner { // Java system properties etc. val sparkRBackend = new RBackend() @volatile var sparkRBackendPort = 0 + @volatile var sparkRBackendSecret: String = null val initialized = new Semaphore(0) val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackendPort = sparkRBackend.init() + val (port, authHelper) = sparkRBackend.init() + sparkRBackendPort = port + sparkRBackendSecret = authHelper.secret initialized.release() sparkRBackend.run() } @@ -91,6 +94,7 @@ object RRunner { env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() From 7aaa148f593470b2c32221b69097b8b54524eb74 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 9 May 2018 11:09:19 -0700 Subject: [PATCH 460/593] [SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs ## What changes were proposed in this pull request? Provide evaluateEachIteration method or equivalent for spark.ml GBTs. ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21097 from WeichenXu123/GBTeval. --- .../ml/classification/GBTClassifier.scala | 15 +++++++++ .../spark/ml/regression/GBTRegressor.scala | 17 +++++++++- .../org/apache/spark/ml/tree/treeParams.scala | 6 +++- .../classification/GBTClassifierSuite.scala | 29 ++++++++++++++++- .../ml/regression/GBTRegressorSuite.scala | 32 +++++++++++++++++-- 5 files changed, 94 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 0aa24f0a3cfcc..3fb6d1e4e4f3e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -334,6 +334,21 @@ class GBTClassificationModel private[ml]( // hard coded loss, which is not meant to be changed in the model private val loss = getOldLossType + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss, + OldAlgo.Classification + ) + } + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8598e808c4946..d7e054bf55ef6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ /** @@ -269,6 +269,21 @@ class GBTRegressionModel private[ml]( new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + * @param loss The loss function used to compute error. Supported options: squared, absolute + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, + convertToOldLossType(loss), OldAlgo.Regression) + } + @Since("2.0.0") override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 81b6222acc7ce..ec8868bb42cbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -579,7 +579,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { - getLossType match { + convertToOldLossType(getLossType) + } + + private[ml] def convertToOldLossType(loss: String): OldLoss = { + loss match { case "squared" => OldSquaredError case "absolute" => OldAbsoluteError case _ => diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index f0ee5496f9d1d..e20de196d65ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.RegressionLeafNode -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -365,6 +365,33 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + val gbt = new GBTClassifier() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType("logistic") + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTClassificationModel("gbt-cls-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses) + val model2 = new GBTClassificationModel("gbt-cls-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses) + + val evalArr = model3.evaluateEachIteration(validationData.toDF) + val remappedValidationData = validationData.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData, + model1.trees, model1.treeWeights, model1.getOldLossType) + val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData, + model2.trees, model2.treeWeights, model2.getOldLossType) + val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData, + model3.trees, model3.treeWeights, model3.getOldLossType) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index fad11d078250f..773f6d2c542fe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -201,7 +202,34 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } - + test("model evaluateEachIteration") { + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType(lossType) + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTRegressionModel("gbt-reg-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures) + val model2 = new GBTRegressionModel("gbt-reg-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures) + + for (evalLossType <- GBTRegressor.supportedLossTypes) { + val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType) + val lossErr1 = GradientBoostedTrees.computeError(validationData, + model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType)) + val lossErr2 = GradientBoostedTrees.computeError(validationData, + model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType)) + val lossErr3 = GradientBoostedTrees.computeError(validationData, + model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType)) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + } + } ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load From fd1179c17273283d32f275d5cd5f97aaa2aca1f7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 9 May 2018 11:32:17 -0700 Subject: [PATCH 461/593] [SPARK-24214][SS] Fix toJSON for StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation ## What changes were proposed in this pull request? We should overwrite "otherCopyArgs" to provide the SparkSession parameter otherwise TreeNode.toJSON cannot get the full constructor parameter list. ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #21275 from zsxwing/SPARK-24214. --- .../execution/streaming/StreamingRelation.scala | 3 +++ .../spark/sql/streaming/StreamingQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f02d3a2c3733f..24195b5657e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -66,6 +66,7 @@ case class StreamingExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString @@ -97,6 +98,7 @@ case class StreamingRelationV2( output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = sourceName @@ -116,6 +118,7 @@ case class ContinuousExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 0cb2375e0a49a..57986999bf861 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -831,6 +831,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + + "should not fail") { + val df = spark.readStream.format("rate").load() + assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) + + testStream(df)( + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + ) + + testStream(df, useV2Sink = true)( + StartStream(trigger = Trigger.Continuous(100)), + AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation")) + ) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) From 9e3bb313682bf88d0c81427167ee341698d29b02 Mon Sep 17 00:00:00 2001 From: wuyi Date: Wed, 9 May 2018 15:44:36 -0700 Subject: [PATCH 462/593] [SPARK-24141][CORE] Fix bug in CoarseGrainedSchedulerBackend.killExecutors ## What changes were proposed in this pull request? In method *CoarseGrainedSchedulerBackend.killExecutors()*, `numPendingExecutors` should add `executorsToKill.size` rather than `knownExecutors.size` if we do not adjust target number of executors. ## How was this patch tested? N/A Author: wuyi Closes #21209 from Ngone51/SPARK-24141. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5627a557a12f3..d8794e8e551aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -633,7 +633,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } doRequestTotalExecutors(requestedTotalExecutors) } else { - numPendingExecutors += knownExecutors.size + numPendingExecutors += executorsToKill.size Future.successful(true) } From 9341c951e85ff29714cbee302053872a6a4223da Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 9 May 2018 19:56:03 -0700 Subject: [PATCH 463/593] [SPARK-23852][SQL] Add test that fails if PARQUET-1217 is not fixed ## What changes were proposed in this pull request? Add a new test that triggers if PARQUET-1217 - a predicate pushdown bug - is not fixed in Spark's Parquet dependency. ## How was this patch tested? New unit test passes. Author: Henry Robinson Closes #21284 from henryr/spark-23852. --- .../test/resources/test-data/parquet-1217.parquet | Bin 0 -> 321 bytes .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++++++ 2 files changed, 10 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/parquet-1217.parquet diff --git a/sql/core/src/test/resources/test-data/parquet-1217.parquet b/sql/core/src/test/resources/test-data/parquet-1217.parquet new file mode 100644 index 0000000000000000000000000000000000000000..eb2dc4f79907019ffe4edaa9adb951924aad1c5d GIT binary patch literal 321 zcmbu5(MrQG7=^Pml#v^~@DB~-BFI`Q)R6|)b+DV=$S%Yc=L=*_hK0@zhq6oG%h&Ne zT(Vd2z~P6(;p68ti 0").count() === 2) + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 62d01391fee77eedd75b4e3f475ede8b9f0df0c2 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 9 May 2018 21:48:54 -0700 Subject: [PATCH 464/593] [SPARK-24073][SQL] Rename DataReaderFactory to InputPartition. ## What changes were proposed in this pull request? Renames: * `DataReaderFactory` to `InputPartition` * `DataReader` to `InputPartitionReader` * `createDataReaderFactories` to `planInputPartitions` * `createUnsafeDataReaderFactories` to `planUnsafeInputPartitions` * `createBatchDataReaderFactories` to `planBatchInputPartitions` This fixes the changes in SPARK-23219, which renamed ReadTask to DataReaderFactory. The intent of that change was to make the read and write API match (write side uses DataWriterFactory), but the underlying problem is that the two classes are not equivalent. ReadTask/DataReader function as Iterable/Iterator. One InputPartition is a specific partition of the data to be read, in contrast to DataWriterFactory where the same factory instance is used in all write tasks. InputPartition's purpose is to manage the lifecycle of the associated reader, which is now called InputPartitionReader, with an explicit create operation to mirror the close operation. This was no longer clear from the API because DataReaderFactory appeared to be more generic than it is and it isn't clear why a set of them is produced for a read. ## How was this patch tested? Existing tests, which have been updated to use the new name. Author: Ryan Blue Closes #21145 from rdblue/SPARK-24073-revert-data-reader-factory-rename. --- .../sql/kafka010/KafkaContinuousReader.scala | 20 ++--- .../sql/kafka010/KafkaMicroBatchReader.scala | 21 ++--- .../sql/kafka010/KafkaSourceProvider.scala | 3 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sql/sources/v2/MicroBatchReadSupport.java | 2 +- ...ory.java => ContinuousInputPartition.java} | 8 +- .../sources/v2/reader/DataSourceReader.java | 10 +-- ...ReaderFactory.java => InputPartition.java} | 16 ++-- ...aReader.java => InputPartitionReader.java} | 4 +- .../v2/reader/SupportsReportPartitioning.java | 2 +- .../v2/reader/SupportsScanColumnarBatch.java | 10 +-- .../v2/reader/SupportsScanUnsafeRow.java | 8 +- .../partitioning/ClusteredDistribution.java | 4 +- .../v2/reader/partitioning/Distribution.java | 6 +- .../v2/reader/partitioning/Partitioning.java | 4 +- ...va => ContinuousInputPartitionReader.java} | 6 +- .../v2/reader/streaming/ContinuousReader.java | 10 +-- .../v2/reader/streaming/MicroBatchReader.java | 2 +- .../datasources/v2/DataSourceRDD.scala | 11 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 46 +++++------ .../continuous/ContinuousDataSourceRDD.scala | 22 +++--- .../ContinuousQueuedDataReader.scala | 13 ++-- .../ContinuousRateStreamSource.scala | 20 ++--- .../sql/execution/streaming/memory.scala | 12 +-- .../sources/ContinuousMemoryStream.scala | 18 ++--- .../sources/RateStreamMicroBatchReader.scala | 15 ++-- .../execution/streaming/sources/socket.scala | 29 +++---- .../sources/v2/JavaAdvancedDataSourceV2.java | 22 +++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 12 +-- .../v2/JavaPartitionAwareDataSource.java | 12 +-- .../v2/JavaSchemaRequiredDataSource.java | 4 +- .../sources/v2/JavaSimpleDataSourceV2.java | 18 ++--- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 16 ++-- .../sources/RateStreamProviderSuite.scala | 12 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 78 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 14 ++-- .../sql/streaming/StreamingQuerySuite.scala | 11 +-- .../ContinuousQueuedDataReaderSuite.scala | 8 +- .../sources/StreamingDataSourceV2Suite.scala | 4 +- 39 files changed, 272 insertions(+), 263 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ContinuousDataReaderFactory.java => ContinuousInputPartition.java} (78%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataReaderFactory.java => InputPartition.java} (81%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataReader.java => InputPartitionReader.java} (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/{ContinuousDataReader.java => ContinuousInputPartitionReader.java} (84%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index f26c134c2f6e9..88abf8a8dd027 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType /** @@ -86,7 +86,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -108,7 +108,7 @@ class KafkaContinuousReader( case (topicPartition, start) => KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[DataReaderFactory[UnsafeRow]] + .asInstanceOf[InputPartition[UnsafeRow]] }.asJava } @@ -161,18 +161,18 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] require(kafkaOffset.topicPartition == topicPartition, s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousDataReader( + new KafkaContinuousInputPartitionReader( topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader( + override def createPartitionReader(): KafkaContinuousInputPartitionReader = { + new KafkaContinuousInputPartitionReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } } @@ -187,12 +187,12 @@ case class KafkaContinuousDataReaderFactory( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -class KafkaContinuousDataReader( +class KafkaContinuousInputPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index cbe655f9bff1f..8a377738ea782 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader( new KafkaMicroBatchDataReaderFactory( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } - factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava } override def getStartOffset: Offset = { @@ -299,27 +299,28 @@ private[kafka010] class KafkaMicroBatchReader( } } -/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ private[kafka010] case class KafkaMicroBatchDataReaderFactory( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = + new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, reuseKafkaConsumer) } -/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReader( +/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchInputPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 36b9f0466566b..d225c1ea6b7f1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -149,7 +150,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Creates a [[ContinuousInputPartitionReader]] to read * Kafka data in a continuous streaming query. */ override def createContinuousReader( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index d2d04b68de6ab..871f9700cd1db 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.createUnsafeRowReaderFactories().asScala + val factories = reader.planUnsafeInputPartitions().asScala .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 209ffa7a0b9fa..7f4a2c9593c76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -34,7 +34,7 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index a61697649c43e..c24f3b21eade1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -21,15 +21,15 @@ import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can - * implement this interface to provide creating {@link DataReader} with particular offset. + * A mix-in interface for {@link InputPartition}. Continuous input partitions can + * implement this interface to provide creating {@link InputPartitionReader} with particular offset. */ @InterfaceStability.Evolving -public interface ContinuousDataReaderFactory extends DataReaderFactory { +public interface ContinuousInputPartition extends InputPartition { /** * Create a DataReader with particular offset as its startOffset. * * @param offset offset want to set as the DataReader's startOffset. */ - DataReader createDataReaderWithOffset(PartitionOffset offset); + InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index a470bccc5aad2..f898c296e4245 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,8 +31,8 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link DataReaderFactory}s that are returned by - * {@link #createDataReaderFactories()}. + * logic is delegated to {@link InputPartition}s that are returned by + * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -65,8 +65,8 @@ public interface DataSourceReader { StructType readSchema(); /** - * Returns a list of reader factories. Each factory is responsible for creating a data reader to - * output data for one RDD partition. That means the number of factories returned here is same as + * Returns a list of read tasks. Each task is responsible for creating a data reader to + * output data for one RDD partition. That means the number of tasks returned here is same as * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization @@ -76,5 +76,5 @@ public interface DataSourceReader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createDataReaderFactories(); + List> planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 81% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 32e98e8f5d8bd..c581e3b5d0047 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,20 +22,20 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is + * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is * responsible for creating the actual data reader. The relationship between - * {@link DataReaderFactory} and {@link DataReader} + * {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the reader factory will be serialized and sent to executors, then the data reader - * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be - * serializable and {@link DataReader} doesn't need to be. + * Note that input partitions will be serialized and sent to executors, then the partition reader + * will be created on executors and do the actual reading. So {@link InputPartition} must be + * serializable and {@link InputPartitionReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataReaderFactory extends Serializable { +public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this reader factory can run faster, + * The preferred locations where the data reader returned by this partition can run faster, * but Spark does not guarantee to run the data reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. @@ -57,5 +57,5 @@ default String[] preferredLocations() { * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ - DataReader createDataReader(); + InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index bb9790a1c819e..1b7051f1ad0af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for + * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data @@ -31,7 +31,7 @@ * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving -public interface DataReader extends Closeable { +public interface InputPartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 607628746e873..6b60da7c4dc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -24,7 +24,7 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 2e5cfa78511f0..0faf81db24605 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,22 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); + "planInputPartitions not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data * in batches. */ - List> createBatchDataReaderFactories(); + List> planBatchInputPartitions(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. + * {@link #planInputPartitions()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 9cd749e8e4ce9..f2220f6d31093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,14 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); + "planInputPartitions not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#planInputPartitions()}, * but returns data in unsafe row format. */ - List> createUnsafeRowReaderFactories(); + List> planUnsafeInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 2d0ee50212b56..38ca5fc6387b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link DataReader}. + * {@link InputPartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index f6b111fdf220d..d2ee9518d628f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link DataReader}). + * partition(the output records of a single {@link InputPartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 309d9e5de0a0f..f460f6bfe3bb9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** @@ -31,7 +31,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. + * Returns the number of partitions(i.e., {@link InputPartition}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java similarity index 84% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java index 47d26440841fd..7b0ba0bbdda90 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** - * A variation on {@link DataReader} for use with streaming in continuous processing mode. + * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. */ @InterfaceStability.Evolving -public interface ContinuousDataReader extends DataReader { +public interface ContinuousInputPartitionReader extends InputPartitionReader { /** * Get the offset of the current record, or the start offset if no records have been read. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 7fe7f00ac2fa8..716c5c0e9e15a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. + * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -35,7 +35,7 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); @@ -47,7 +47,7 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset deserializeOffset(String json); /** - * Set the desired start offset for reader factories created from this reader. The scan will + * Set the desired start offset for partitions created from this reader. The scan will * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ @@ -61,8 +61,8 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new reader - * factories need to be generated, which may be required if for example the underlying + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index 67ebde30d61a9..0159c731762d9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -33,7 +33,7 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** - * Set the desired offset range for reader factories created from this reader. Reader factories + * Set the desired offset range for input partitions created from this reader. Partition readers * will generate only data within (`start`, `end`]; that is, from the first record after `start` * to the record with offset `end`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f85971be394b1..1a6b32429313a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,14 +22,14 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { @@ -39,7 +39,8 @@ class DataSourceRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition + .createPartitionReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +64,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 77cb707340b0f..c6a7684bf6ab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -59,13 +59,13 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => SinglePartition - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => SinglePartition - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => SinglePartition case s: SupportsReportPartitioning => @@ -75,19 +75,19 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala + private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala case _ => - reader.createDataReaderFactories().asScala.map { - new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] + reader.planInputPartitions().asScala.map { + new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] } } - private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { + private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - r.createBatchDataReaderFactories().asScala + r.planBatchInputPartitions().asScala } private lazy val inputRDD: RDD[InternalRow] = reader match { @@ -95,19 +95,18 @@ case class DataSourceV2ScanExec( EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size)) + .askSync[Unit](SetReaderPartitions(partitions.size)) new ContinuousDataSourceRDD( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - readerFactories) - .asInstanceOf[RDD[InternalRow]] + partitions).asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -132,19 +131,22 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) - extends DataReaderFactory[UnsafeRow] { +class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) + extends InputPartition[UnsafeRow] { - override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations + override def preferredLocations: Array[String] = partition.preferredLocations - override def createDataReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader( - rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + override def createPartitionReader: InputPartitionReader[UnsafeRow] = { + new RowToUnsafeInputPartitionReader( + partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) } } -class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) - extends DataReader[UnsafeRow] { +class RowToUnsafeInputPartitionReader( + val rowReader: InputPartitionReader[Row], + encoder: ExpressionEncoder[Row]) + + extends InputPartitionReader[UnsafeRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 0a3b9dcccb6c5..a7ccce10b0cee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,14 +21,14 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} import org.apache.spark.util.{NextIterator, ThreadUtils} class ContinuousDataSourceRDDPartition( val index: Int, - val readerFactory: DataReaderFactory[UnsafeRow]) + val inputPartition: InputPartition[UnsafeRow]) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -51,12 +51,12 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { readerFactories.zipWithIndex.map { - case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -75,7 +75,7 @@ class ContinuousDataSourceRDD( if (partition.queueReader == null) { partition.queueReader = new ContinuousQueuedDataReader( - partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -96,17 +96,17 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() } } object ContinuousDataSourceRDD { private[continuous] def getContinuousReader( - reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case r: ContinuousInputPartitionReader[UnsafeRow] => r + case wrapped: RowToUnsafeInputPartitionReader => + wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] case _ => throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 01a999f6505fc..d8645576c2052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.Closeable -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.util.control.NonFatal -import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils @@ -38,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - factory: DataReaderFactory[UnsafeRow], + partition: InputPartition[UnsafeRow], context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = factory.createDataReader() + private val reader = partition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -132,7 +131,7 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when - * a new row arrives to the [[DataReader]]. + * a new row arrives to the [[InputPartitionReader]]. */ class DataReaderThread extends Thread( s"continuous-reader--${context.partitionId()}--" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..8d25d9ccc43d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -91,7 +91,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) i, numPartitions, perPartitionRate) - .asInstanceOf[DataReaderFactory[Row]] + .asInstanceOf[InputPartition[Row]] }.asJava } @@ -119,13 +119,13 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReaderFactory[Row] { + extends ContinuousInputPartition[Row] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = { val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] require(rateStreamOffset.partition == partitionIndex, s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousDataReader( + new RateStreamContinuousInputPartitionReader( rateStreamOffset.currentValue, rateStreamOffset.currentTimeMs, partitionIndex, @@ -133,18 +133,18 @@ case class RateStreamContinuousDataReaderFactory( rowsPerSecond) } - override def createDataReader(): DataReader[Row] = - new RateStreamContinuousDataReader( + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamContinuousInputPartitionReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamContinuousDataReader( +class RateStreamContinuousInputPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReader[Row] { + extends ContinuousInputPartitionReader[Row] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6720cdd24b1b2..daa2963220aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -139,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -202,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) - extends DataReaderFactory[UnsafeRow] { - override def createDataReader(): DataReader[UnsafeRow] = { - new DataReader[UnsafeRow] { + extends InputPartition[UnsafeRow] { + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { + new InputPartitionReader[UnsafeRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index a8fca3c19a2d2..fef792eab69d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -34,8 +34,8 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -99,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) ) } - override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): ju.List[InputPartition[Row]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = @@ -108,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) startOffset.partitionNums.map { case (part, index) => new ContinuousMemoryStreamDataReaderFactory( - endpointName, part, index): DataReaderFactory[Row] + endpointName, part, index): InputPartition[Row] }.toList.asJava } } @@ -160,9 +160,9 @@ object ContinuousMemoryStream { class ContinuousMemoryStreamDataReaderFactory( driverEndpointName: String, partition: Int, - startOffset: Int) extends DataReaderFactory[Row] { - override def createDataReader: ContinuousMemoryStreamDataReader = - new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition[Row] { + override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = + new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } /** @@ -170,10 +170,10 @@ class ContinuousMemoryStreamDataReaderFactory( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamDataReader( +class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousDataReader[Row] { + startOffset: Int) extends ContinuousInputPartitionReader[Row] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index f54291bea6678..723cc3ad5bb89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") @@ -169,7 +169,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchDataReaderFactory( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] + : InputPartition[Row] }.toList.asJava } @@ -188,19 +188,20 @@ class RateStreamMicroBatchDataReaderFactory( rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { + relativeMsPerValue: Double) extends InputPartition[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, + localStartTimeMs, relativeMsPerValue) } -class RateStreamMicroBatchDataReader( +class RateStreamMicroBatchInputPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { + relativeMsPerValue: Double) extends InputPartitionReader[Row] { private var count = 0 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 90f4a5ba4234d..8240e06d4ab72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -140,7 +140,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") @@ -165,21 +165,22 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR (0 until numPartitions).map { i => val slice = slices(i) - new DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new DataReader[Row] { - private var currentIdx = -1 + new InputPartition[Row] { + override def createPartitionReader(): InputPartitionReader[Row] = + new InputPartitionReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) + override def close(): Unit = {} } - - override def close(): Unit = {} - } } }.toList.asJava } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 172e5d5eebcbe..714638e500c94 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -79,8 +79,8 @@ public Filter[] pushedFilters() { } @Override - public List> createDataReaderFactories() { - List> res = new ArrayList<>(); + public List> planInputPartitions() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -94,33 +94,33 @@ public List> createDataReaderFactories() { } if (lowerBound == null) { - res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { + JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; } @Override - public DataReader createDataReader() { - return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); + public InputPartitionReader createPartitionReader() { + return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index c55093768105b..97d6176d02559 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,14 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchDataReaderFactories() { + public List> planBatchInputPartitions() { return java.util.Arrays.asList( - new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); + new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); } } - static class JavaBatchDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaBatchInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; @@ -59,13 +59,13 @@ static class JavaBatchDataReaderFactory private OnHeapColumnVector j; private ColumnarBatch batch; - JavaBatchDataReaderFactory(int start, int end) { + JavaBatchInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 32fad59b97ff6..e49c8cf8b9e16 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -43,10 +43,10 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -73,12 +73,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { + static class SpecificInputPartition implements InputPartition, InputPartitionReader { private int[] i; private int[] j; private int current = -1; - SpecificDataReaderFactory(int[] i, int[] j) { + SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; @@ -101,7 +101,7 @@ public void close() throws IOException { } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { return this; } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 048d078dfaac4..80eeffd95f83b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 96f55b8a76811..8522a63898a3b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -25,8 +25,8 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new JavaSimpleDataReaderFactory(0, 5), - new JavaSimpleDataReaderFactory(5, 10)); + new JavaSimpleInputPartition(0, 5), + new JavaSimpleInputPartition(5, 10)); } } - static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaSimpleInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; - JavaSimpleDataReaderFactory(int start, int end) { + JavaSimpleInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { - return new JavaSimpleDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaSimpleInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index c3916e0b370b5..3ad8e7a0104ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,20 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReaderFactories() { + public List> planUnsafeInputPartitions() { return java.util.Arrays.asList( - new JavaUnsafeRowDataReaderFactory(0, 5), - new JavaUnsafeRowDataReaderFactory(5, 10)); + new JavaUnsafeRowInputPartition(0, 5), + new JavaUnsafeRowInputPartition(5, 10)); } } - static class JavaUnsafeRowDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaUnsafeRowInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowDataReaderFactory(int start, int end) { + JavaUnsafeRowInputPartition(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,8 +59,8 @@ static class JavaUnsafeRowDataReaderFactory } @Override - public DataReader createDataReader() { - return new JavaUnsafeRowDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaUnsafeRowInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a8..39a010f970ce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -142,9 +142,9 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() + val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[Row]() while (dataReader.next()) { data.append(dataReader.get()) @@ -159,11 +159,11 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala - .map(_.createDataReader()) + .map(_.createPartitionReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[Row]() while (reader.next()) buf.append(reader.get()) @@ -304,7 +304,7 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() @@ -314,7 +314,7 @@ class RateSourceSuite extends StreamTest { .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0a53272cd222..505a3f3465c02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -346,8 +346,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -359,20 +359,21 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SimpleInputPartition(start: Int, end: Int) + extends InputPartition[Row] + with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) + override def createPartitionReader(): InputPartitionReader[Row] = + new SimpleInputPartition(start, end) override def next(): Boolean = { current += 1 @@ -413,21 +414,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[DataReaderFactory[Row]] + val res = new ArrayList[InputPartition[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(0, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) } res @@ -437,13 +438,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) - extends DataReaderFactory[Row] with DataReader[Row] { +class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) + extends InputPartition[Row] with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = { - new AdvancedDataReaderFactory(start, end, requiredSchema) + override def createPartitionReader(): InputPartitionReader[Row] = { + new AdvancedInputPartition(start, end, requiredSchema) } override def close(): Unit = {} @@ -468,24 +469,24 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), - new UnsafeRowDataReaderFactory(5, 10)) + override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), + new UnsafeRowInputPartitionReader(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowInputPartitionReader(start: Int, end: Int) + extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = this + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -503,7 +504,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceReader { - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = + override def planInputPartitions(): JList[InputPartition[Row]] = java.util.Collections.emptyList() } @@ -516,16 +517,17 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + java.util.Arrays.asList( + new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class BatchDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchInputPartitionReader(start: Int, end: Int) + extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -534,7 +536,7 @@ class BatchDataReaderFactory(start: Int, end: Int) private var current = start - override def createDataReader(): DataReader[ColumnarBatch] = this + override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this override def next(): Boolean = { i.reset() @@ -568,11 +570,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -590,14 +592,14 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) + extends InputPartition[Row] + with InputPartitionReader[Row] { assert(i.length == j.length) private var current = -1 - override def createDataReader(): DataReader[Row] = this + override def createPartitionReader(): InputPartitionReader[Row] = this override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a5007fa321359..694bb3b95b0f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,9 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVDataReaderFactory( + new SimpleCSVInputPartitionReader( f.getPath.toUri.toString, - serializableConf): DataReaderFactory[Row] + serializableConf): InputPartition[Row] }.toList.asJava } else { Collections.emptyList() @@ -156,14 +156,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) - extends DataReaderFactory[Row] with DataReader[Row] { +class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) + extends InputPartition[Row] with InputPartitionReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createDataReader(): DataReader[Row] = { + override def createPartitionReader(): InputPartitionReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 57986999bf861..dcf6cb5d609ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { clock.waitTillTime(1350) - super.createUnsafeRowReaderFactories() + super.planUnsafeInputPartitions() } } } @@ -290,13 +290,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AdvanceManualClock(100), // time = 1150 to unblock getEndOffset AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 + // will block on planInputPartitions that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1350), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions AssertClockTime(1350), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e755625d09e0f..f47d3ec8ae025 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -72,8 +72,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new DataReaderFactory[UnsafeRow] { - override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + val factory = new InputPartition[UnsafeRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { var index = -1 var curr: UnsafeRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index af4618bed5456..c1a28b9bc75ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} @@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { + def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From e3d434994733ae16e7e1424fb6de2d22b1a13f99 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 10 May 2018 13:36:52 +0800 Subject: [PATCH 465/593] [SPARK-22279][SQL] Enable `convertMetastoreOrc` by default ## What changes were proposed in this pull request? We reverted `spark.sql.hive.convertMetastoreOrc` at https://github.com/apache/spark/pull/20536 because we should not ignore the table-specific compression conf. Now, it's resolved via [SPARK-23355](https://github.com/apache/spark/commit/8aa1d7b0ede5115297541d29eab4ce5f4fe905cb). ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21186 from dongjoon-hyun/SPARK-24112. --- docs/sql-programming-guide.md | 3 ++- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e8946e424237..3f79ed6422205 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1017,7 +1017,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also - + @@ -1813,6 +1813,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 10c9603745379..bb134bbe68bd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -105,11 +105,10 @@ private[spark] object HiveUtils extends Logging { .createWithDefault(false) val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") - .internal() .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 94d671448240c8f6da11d2523ba9e4ae5b56a410 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 10 May 2018 20:38:52 +0900 Subject: [PATCH 466/593] [SPARK-23907][SQL] Add regr_* functions ## What changes were proposed in this pull request? The PR introduces regr_slope, regr_intercept, regr_r2, regr_sxx, regr_syy, regr_sxy, regr_avgx, regr_avgy, regr_count. The implementation of this functions mirrors Hive's one in HIVE-15978. ## How was this patch tested? added UT (values compared with Hive) Author: Marco Gaido Closes #21054 from mgaido91/SPARK-23907. --- .../catalyst/analysis/FunctionRegistry.scala | 9 + .../expressions/aggregate/Average.scala | 47 +++-- .../aggregate/CentralMomentAgg.scala | 60 +++--- .../catalyst/expressions/aggregate/Corr.scala | 52 ++--- .../expressions/aggregate/Count.scala | 47 +++-- .../expressions/aggregate/Covariance.scala | 36 ++-- .../expressions/aggregate/regression.scala | 190 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 172 ++++++++++++++++ .../sql-tests/inputs/udaf-regrfunctions.sql | 56 ++++++ .../results/udaf-regrfunctions.sql.out | 93 +++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 71 ++++++- 11 files changed, 721 insertions(+), 112 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 87b0911e150c5..087d000a9db70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,15 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[RegrCount]("regr_count"), + expression[RegrSXX]("regr_sxx"), + expression[RegrSYY]("regr_syy"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), + expression[RegrSXY]("regr_sxy"), + expression[RegrSlope]("regr_slope"), + expression[RegrR2]("regr_r2"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbfc36058..a133bc2361eb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil +abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override def nullable: Boolean = true - // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") - private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) @@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right @@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _ => Cast(sum, resultType) / Cast(count, resultType) } + + protected def updateExpressionsDef: Seq[Expression] = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override lazy val updateExpressions = updateExpressionsDef +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) + extends AverageLike(child) with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 572d29caf5bc9..6bbb083f1e18e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression) override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val delta = child - avg - val deltaN = delta / newN - val newAvg = avg + deltaN - val newM2 = m2 + delta * (delta - deltaN) - - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val newM3 = if (momentOrder >= 3) { - m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) - } else { - Literal(0.0) - } - val newM4 = if (momentOrder >= 4) { - m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + - delta * (delta * delta2 - deltaN * deltaN2) - } else { - Literal(0.0) - } - - trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) - )) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression) trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) + } } // Compute the population standard deviation of a column diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 95a4a0d5af634..3cdef72c1f2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ /** - * Compute Pearson correlation between two expressions. + * Base class for computing Pearson correlation between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. * * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") -// scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef + + override val mergeExpressions: Seq[Expression] = { + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) + } + + protected def updateExpressionsDef: Seq[Expression] = { val newN = n + Literal(1.0) val dx = x - xAvg val dxN = dx / newN @@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression) If(isNull, yMk, newYMk) ) } +} - override val mergeExpressions: Seq[Expression] = { - - val n1 = n.left - val n2 = n.right - val newN = n1 + n2 - val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) - val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) - val newXAvg = xAvg.left + dxN * n2 - val newYAvg = yAvg.left + dyN * n2 - val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 - val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 - Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) - } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") +// scalastyle:on line.size.limit +case class Corr(x: Expression, y: Expression) + extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1990f2f2f0722..40582d0abd762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. - """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate { - +/** + * Base class for all counting aggregators. + */ +abstract class CountLike extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType - private lazy val count = AttributeReference("count", LongType, nullable = false)() + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. + """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends CountLike { + override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) if (nullableChildren.isEmpty) { @@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index fc6c34baafdd1..72a7c62b328ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) - override lazy val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val dx = x - xAvg - val dy = y - yAvg - val dyN = dy / newN - val newXAvg = xAvg + dx / newN - val newYAvg = yAvg + dyN - val newCk = ck + dx * (y - newYAvg) - - val isNull = IsNull(x) || IsNull(y) - Seq( - If(isNull, n, newN), - If(isNull, xAvg, newXAvg), - If(isNull, yAvg, newYAvg), - If(isNull, ck, newCk) - ) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression) Seq(newN, newXAvg, newYAvg, newCk) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala new file mode 100644 index 0000000000000..d8f4505588ff2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{AbstractDataType, DoubleType} + +/** + * Base trait for all regression functions. + */ +trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { + def y: Expression + def x: Expression + + override def children: Seq[Expression] = Seq(y, x) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { + assert(aggBufferAttributes.length == exprs.length) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + exprs + } else { + exprs.zip(aggBufferAttributes).map { case (e, a) => + If(nullableChildren.map(IsNull).reduce(Or), a, e) + } + } + } +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", + since = "2.4.0") +case class RegrCount(y: Expression, x: Expression) + extends CountLike with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) + + override def prettyName: String = "regr_count" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSXX(y: Expression, x: Expression) + extends CentralMomentAgg(x) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_sxx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSYY(y: Expression, x: Expression) + extends CentralMomentAgg(y) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_syy" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgX(y: Expression, x: Expression) + extends AverageLike(x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgY(y: Expression, x: Expression) + extends AverageLike(y) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgy" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSXY(y: Expression, x: Expression) + extends Covariance(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), ck) + } + + override def prettyName: String = "regr_sxy" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSlope(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) + } + + override def prettyName: String = "regr_slope" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrR2(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) + } + + override def prettyName: String = "regr_r2" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + xAvg - (ck / yMk) * yAvg) + } + + override def prettyName: String = "regr_intercept" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8f9e4ae18b3f1..28cf705eb9700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -775,6 +775,178 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: Column, x: Column): Column = withAggregateFunction { + RegrCount(y.expr, x.expr) + } + + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { + RegrSXX(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: Column, x: Column): Column = withAggregateFunction { + RegrSYY(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgX(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { + RegrSXY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: Column, x: Column): Column = withAggregateFunction { + RegrSlope(y.expr, x.expr) + } + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: Column, x: Column): Column = withAggregateFunction { + RegrR2(y.expr, x.expr) + } + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { + RegrIntercept(y.expr, x.expr) + } + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) + + + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql new file mode 100644 index 0000000000000..92c7e26e3add2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -0,0 +1,56 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x); + +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px; + + +select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out new file mode 100644 index 0000000000000..d7d009a64bf84 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px +-- !query 1 schema +struct +-- !query 1 output +1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 +2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 +3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 +4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 +5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 + + +-- !query 2 +select id, regr_count(y,x) over (partition by px) from t1 order by id +-- !query 2 schema +struct +-- !query 2 output +101 4 +102 4 +103 4 +104 4 +201 4 +202 4 +203 4 +204 4 +301 4 +302 4 +303 4 +304 4 +401 4 +402 4 +403 4 +404 4 +501 0 +502 0 +503 0 +504 0 +601 0 +602 0 +603 0 +604 0 +701 0 +702 0 +703 0 +704 0 +800 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7776e36702ad..4337fb2290fbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ + val absTol = 1e-8 + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -686,4 +687,72 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23907: regression functions") { + val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") + val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) + .toDF("a", "b") + val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( + (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") + checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) + checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) + checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), + Row(null), absTol) + + + checkAggregatesWithTol(correlatedData.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), + absTol) + checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), + absTol) + } } From f4fed0512101a67d9dae50ace11d3940b910e05e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 10 May 2018 09:44:49 -0700 Subject: [PATCH 467/593] [SPARK-24171] Adding a note for non-deterministic functions ## What changes were proposed in this pull request? I propose to add a clear statement for functions like `collect_list()` about non-deterministic behavior of such functions. The behavior must be taken into account by user while creating and running queries. Author: Maxim Gekk Closes #21228 from MaxGekk/deterministic-comments. --- R/pkg/R/functions.R | 11 +++++ python/pyspark/sql/functions.py | 18 ++++++++ .../MonotonicallyIncreasingID.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/randomExpressions.scala | 8 ++-- .../org/apache/spark/sql/functions.scala | 46 +++++++++++++++++-- 6 files changed, 81 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0ec99d19e21e4..04d0e4620b28a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -805,6 +805,8 @@ setMethod("factorial", #' #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param na.rm a logical value indicating whether NA values should be stripped #' before the computation proceeds. @@ -948,6 +950,8 @@ setMethod("kurtosis", #' #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param x column to compute on. #' @param na.rm a logical value indicating whether NA values should be stripped @@ -1201,6 +1205,7 @@ setMethod("minute", #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. #' The method should be used with no argument. +#' Note: the function is non-deterministic because its result depends on partition IDs. #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method @@ -2584,6 +2589,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @details #' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) #' samples from U[0.0, 1.0]. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2612,6 +2618,7 @@ setMethod("rand", signature(seed = "numeric"), #' @details #' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples #' from the standard normal distribution. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -3188,6 +3195,8 @@ setMethod("create_map", #' @details #' \code{collect_list}: Creates a list of objects with duplicates. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method @@ -3207,6 +3216,8 @@ setMethod("collect_list", #' @details #' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ac3c79766702c..f5a584152b4f6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -152,6 +152,9 @@ def _(): _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_list('age')).collect() [Row(collect_list(age)=[2, 5, 5])] @@ -159,6 +162,9 @@ def _(): _collect_set_doc = """ Aggregate function: returns a set of objects with duplicate elements eliminated. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_set('age')).collect() [Row(collect_set(age)=[5, 2])] @@ -401,6 +407,9 @@ def first(col, ignorenulls=False): The function by default returns the first values it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows which + may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) @@ -489,6 +498,9 @@ def last(col, ignorenulls=False): The function by default returns the last values it sees. It will return the last non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows + which may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) @@ -504,6 +516,8 @@ def monotonically_increasing_id(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + .. note:: The function is non-deterministic because its result depends on partition IDs. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -536,6 +550,8 @@ def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('rand', rand(seed=42) * 3).collect() [Row(age=2, name=u'Alice', rand=1.1568609015300986), Row(age=5, name=u'Bob', rand=1.403379671529166)] @@ -554,6 +570,8 @@ def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('randn', randn(seed=42)).collect() [Row(age=2, name=u'Alice', randn=-0.7556247885860078), Row(age=5, name=u'Bob', randn=-0.0861619008451133)] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ad1e7bdb31987..9f0779642271d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.types.{DataType, LongType} puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + The function is non-deterministic because its result depends on partition IDs. """) case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7eda65a867028..b7834696cafc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -117,12 +117,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.", + usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""", examples = """ Examples: > SELECT _FUNC_(); 46707d92-02f4-4817-8116-a4c3b23e6266 - """) + """, + note = "The function is non-deterministic.") // scalastyle:on line.size.limit case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 70186053617f8..2653b28f6c3bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -68,7 +68,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful 0.8446490682263027 > SELECT _FUNC_(null); 0.8446490682263027 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Rand(child: Expression) extends RDG { @@ -96,7 +97,7 @@ object Rand { /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.", + usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""", examples = """ Examples: > SELECT _FUNC_(); @@ -105,7 +106,8 @@ object Rand { 1.1164209726833079 > SELECT _FUNC_(null); 1.1164209726833079 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Randn(child: Expression) extends RDG { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 28cf705eb9700..225de0051d6fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -283,6 +283,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -291,6 +294,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -299,6 +305,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -307,6 +316,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -422,6 +434,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -435,6 +450,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -448,6 +466,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -459,6 +480,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -535,6 +559,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -548,6 +575,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -561,6 +591,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -572,6 +605,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -1344,7 +1380,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1355,6 +1391,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1364,7 +1402,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1375,6 +1413,8 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1383,7 +1423,7 @@ object functions { /** * Partition ID. * - * @note This is indeterministic because it depends on data partitioning and task scheduling. + * @note This is non-deterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 From 6282fc64e32fc2f70e79ace14efd4922e4535dbb Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 10 May 2018 11:36:41 -0700 Subject: [PATCH 468/593] [SPARK-24137][K8S] Mount local directories as empty dir volumes. ## What changes were proposed in this pull request? Drastically improves performance and won't cause Spark applications to fail because they write too much data to the Docker image's specific file system. The file system's directories that back emptydir volumes are generally larger and more performant. ## How was this patch tested? Has been in use via the prototype version of Kubernetes support, but lost in the transition to here. Author: mcheah Closes #21238 from mccheah/mount-local-dirs. --- .../scala/org/apache/spark/SparkConf.scala | 5 +- .../k8s/features/LocalDirsFeatureStep.scala | 77 ++++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 10 +- .../k8s/KubernetesExecutorBuilder.scala | 9 +- .../features/LocalDirsFeatureStepSuite.scala | 111 ++++++++++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 13 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 12 +- 7 files changed, 223 insertions(+), 14 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 129956e9f9ffa..dab409572646f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -454,8 +454,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { - val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + - "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." + val msg = "Note that spark.local.dir will be overridden by the value set by " + + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" + + " in YARN)." logWarning(msg) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala new file mode 100644 index 0000000000000..70b307303d149 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.nio.file.Paths +import java.util.UUID + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class LocalDirsFeatureStep( + conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], + defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}") + extends KubernetesFeatureConfigStep { + + // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system + // property - we want to instead default to mounting an emptydir volume that doesn't already + // exist in the image. + // We could make utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is already + // a bit opinionated about YARN and Mesos. + private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS")) + .orElse(conf.getOption("spark.local.dir")) + .getOrElse(defaultLocalDir) + .split(",") + + override def configurePod(pod: SparkPod): SparkPod = { + val localDirVolumes = resolvedLocalDirs + .zipWithIndex + .map { case (localDir, index) => + new VolumeBuilder() + .withName(s"spark-local-dir-${index + 1}") + .withNewEmptyDir() + .endEmptyDir() + .build() + } + val localDirVolumeMounts = localDirVolumes + .zip(resolvedLocalDirs) + .map { case (localDirVolume, localDirPath) => + new VolumeMountBuilder() + .withName(localDirVolume.getName) + .withMountPath(localDirPath) + .build() + } + val podWithLocalDirVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(localDirVolumes: _*) + .endSpec() + .build() + val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container) + .addNewEnv() + .withName("SPARK_LOCAL_DIRS") + .withValue(resolvedLocalDirs.mkString(",")) + .endEnv() + .addToVolumeMounts(localDirVolumeMounts: _*) + .build() + SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index c7579ed8cb689..10b0154466a3a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -29,14 +29,18 @@ private[spark] class KubernetesDriverBuilder( new DriverServiceFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { val baseFeatures = Seq( provideBasicStep(kubernetesConf), provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf)) + provideServiceStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 22568fe7ea3be..d8f63d57574fb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,18 +17,21 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala new file mode 100644 index 0000000000000..91e184b84b86e --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + private val defaultLocalDir = "/var/data/default-local-dir" + private var sparkConf: SparkConf = _ + private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + + before { + val realSparkConf = new SparkConf(false) + sparkConf = Mockito.spy(realSparkConf) + kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + "resource", + "app-id", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + } + + test("Resolve to default local dir if neither env nor configuration are set") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } + + test("Use configured local dirs split on comma if provided.") { + Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 2) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.pod.getSpec.getVolumes.get(1) === + new VolumeBuilder() + .withName(s"spark-local-dir-2") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 2) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath("/var/data/my-local-dir-1") + .build()) + assert(configuredPod.container.getVolumeMounts.get(1) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-2") + .withMountPath("/var/data/my-local-dir-2") + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .build()) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 161f9afe7bba9..a511d254d2175 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val CREDENTIALS_STEP_TYPE = "credentials" private val SERVICE_STEP_TYPE = "service" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( @@ -36,6 +37,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) @@ -44,7 +48,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, - _ => secretsStep) + _ => secretsStep, + _ => localDirsStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( @@ -64,7 +69,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE) + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index f5270623f8acc..9ee86b5a423a9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,20 +20,24 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, - _ => mountSecretsStep) + _ => mountSecretsStep, + _ => localDirsStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -46,7 +50,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty) - validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -63,6 +68,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE) } From 3e2600538ee477ffe3f23fba57719e035219550b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 10 May 2018 14:26:38 -0700 Subject: [PATCH 469/593] [SPARK-19181][CORE] Fixing flaky "SparkListenerSuite.local metrics" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Sometimes "SparkListenerSuite.local metrics" test fails because the average of executorDeserializeTime is too short. As squito suggested to avoid these situations in one of the task a reference introduced to an object implementing a custom Externalizable.readExternal which sleeps 1ms before returning. ## How was this patch tested? With unit tests (and checking the effect of this change to the average with a much larger sleep time). Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #21280 from attilapiros/SPARK-19181. --- .../spark/scheduler/SparkListenerSuite.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index da6ecb82c7e42..fa47a52bbbc47 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.Semaphore import scala.collection.JavaConverters._ @@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) - // just to make sure some of the tasks take a noticeable amount of time + // just to make sure some of the tasks and their deserialization take a noticeable + // amount of time + val slowDeserializable = new SlowDeserializable val w = { i: Int => if (i == 0) { Thread.sleep(100) + slowDeserializable.use() } i } @@ -583,3 +587,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar case _ => } } + +private class SlowDeserializable extends Externalizable { + + override def writeExternal(out: ObjectOutput): Unit = { } + + override def readExternal(in: ObjectInput): Unit = Thread.sleep(1) + + def use(): Unit = { } +} From d3c426a5b02abdec49ff45df12a8f11f9e473a88 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 10 May 2018 14:41:55 -0700 Subject: [PATCH 470/593] [SPARK-10878][CORE] Fix race condition when multiple clients resolves artifacts at the same time ## What changes were proposed in this pull request? When multiple clients attempt to resolve artifacts via the `--packages` parameter, they could run into race condition when they each attempt to modify the dummy `org.apache.spark-spark-submit-parent-default.xml` file created in the default ivy cache dir. This PR changes the behavior to encode UUID in the dummy module descriptor so each client will operate on a different resolution file in the ivy cache dir. In addition, this patch changes the behavior of when and which resolution files are cleaned to prevent accumulation of resolution files in the default ivy cache dir. Since this PR is a successor of #18801, close #18801. Many codes were ported from #18801. **Many efforts were put here. I think this PR should credit to Victsm .** ## How was this patch tested? added UT into `SparkSubmitUtilsSuite` Author: Kazuaki Ishizaki Closes #21251 from kiszk/SPARK-10878. --- .../org/apache/spark/deploy/SparkSubmit.scala | 42 ++++++++++++++----- .../spark/deploy/SparkSubmitUtilsSuite.scala | 15 +++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 427c797755b84..087e9c31a9c9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab import java.net.URL import java.security.PrivilegedExceptionAction import java.text.ParseException +import java.util.UUID import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -1204,7 +1205,33 @@ private[spark] object SparkSubmitUtils { /** A nice function to use in tests as well. Values are dummy strings. */ def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( - ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + // Include UUID in module name, so multiple clients resolving maven coordinate at the same time + // do not modify the same resolution file concurrently. + ModuleRevisionId.newInstance("org.apache.spark", + s"spark-submit-parent-${UUID.randomUUID.toString}", + "1.0")) + + /** + * Clear ivy resolution from current launch. The resolution file is usually at + * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml, + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties. + * Since each launch will have its own resolution files created, delete them after + * each resolution to prevent accumulation of these files in the ivy cache dir. + */ + private def clearIvyResolutionFiles( + mdId: ModuleRevisionId, + ivySettings: IvySettings, + ivyConfName: String): Unit = { + val currentResolutionFiles = Seq( + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" + ) + currentResolutionFiles.foreach { filename => + new File(ivySettings.getDefaultCache, filename).delete() + } + } /** * Resolves any dependencies that were supplied through maven coordinates @@ -1255,14 +1282,6 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor - // clear ivy resolution from previous launches. The resolution file is usually at - // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file - // leads to confusion with Ivy when the files can no longer be found at the repository - // declared in that file/ - val mdId = md.getModuleRevisionId - val previousResolution = new File(ivySettings.getDefaultCache, - s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") - if (previousResolution.exists) previousResolution.delete md.setDefaultConf(ivyConfName) @@ -1283,7 +1302,10 @@ private[spark] object SparkSubmitUtils { packagesDirectory.getAbsolutePath + File.separator + "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val paths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val mdId = md.getModuleRevisionId + clearIvyResolutionFiles(mdId, ivySettings, ivyConfName) + paths } finally { System.setOut(sysOut) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index eb8c203ae7751..a0f09891787e0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } } + + test("SPARK-10878: test resolution files cleaned after resolving artifact") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + + IvyTestUtils.withRepository(main, None, None) { repo => + val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath)) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + ivySettings, + isTest = true) + val r = """.*org.apache.spark-spark-submit-parent-.*""".r + assert(!ivySettings.getDefaultCache.listFiles.map(_.getName) + .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned") + } + } } From a4206d58e05ab9ed6f01fee57e18dee65cbc4efc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 09:01:40 +0800 Subject: [PATCH 471/593] [SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/20136 . #20136 didn't really work because in the test, we are using local backend, which shares the driver side `SparkEnv`, so `SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER` doesn't work. This PR changes the check to `TaskContext.get != null`, and move the check to `SQLConf.get`, and fix all the places that violate this check: * `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. https://github.com/apache/spark/pull/21223 merged * `DataType#sameType` may be executed in executor side, for things like json schema inference, so we can't call `conf.caseSensitiveAnalysis` there. This contributes to most of the code changes, as we need to add `caseSensitive` parameter to a lot of methods. * `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't can't call `conf.parquetFilterPushDownDate` there. https://github.com/apache/spark/pull/21224 merged * `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. https://github.com/apache/spark/pull/21225 merged * `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. https://github.com/apache/spark/pull/21226 merged ## How was this patch tested? existing test Author: Wenchen Fan Closes #21190 from cloud-fan/minor. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++++-------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 ++- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 +++-- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 188 insertions(+), 140 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..94b0561529e71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -260,7 +261,9 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { + val widerType = TypeCoercion.findWiderTypeForTwo( + dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) + if (widerType.isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index f2df3e132629f..4eb6e642b1c37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,7 +83,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { + val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( + inputTypes, conf.caseSensitiveAnalysis) + val tpe = wideType.getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..a7ba201509b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes :: + WidenSetOperationTypes(conf) :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion :: + FunctionArgumentConversion(conf) :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion :: - IfCoercion :: + CaseWhenCoercion(conf) :: + IfCoercion(conf) :: StackCoercion :: Division :: - new ImplicitTypeCasts(conf) :: + ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,7 +83,10 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + def findTightestCommonType( + left: DataType, + right: DataType, + caseSensitive: Boolean): Option[DataType] = (left, right) match { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -102,22 +105,32 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => + val isSameType = if (caseSensitive) { + DataType.equalsIgnoreNullability(t1, t2) + } else { + DataType.equalsIgnoreCaseAndNullability(t1, t2) + } + + if (isSameType) { + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since t1 is same type of t2, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + } else { + None + } case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) + val keyType = findTightestCommonType(kt1, kt2, caseSensitive) + val valueType = findTightestCommonType(vt1, vt2, caseSensitive) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -172,13 +185,14 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) + def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { + findTightestCommonType(t1, t2, caseSensitive) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2, caseSensitive) + .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -193,7 +207,8 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + private def findWiderCommonType( + types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -201,7 +216,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) + case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) case _ => None }) } @@ -213,20 +228,22 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) + t2: DataType, + caseSensitive: Boolean): Option[DataType] = { + findTightestCommonType(t1, t2, caseSensitive) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) + findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + def findWiderTypeWithoutStringPromotion( + types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) case None => None }) } @@ -279,29 +296,32 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - object WidenSetOperationTypes extends Rule[LogicalPlan] { + case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + val newChildren: Seq[LogicalPlan] = + buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + val newChildren: Seq[LogicalPlan] = + buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes( + children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -316,18 +336,19 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + castedTypes: mutable.Queue[DataType], + caseSensitive: Boolean): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes) + getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) } } @@ -432,7 +453,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType)) + .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) } // The number of columns/expressions must match between LHS and RHS of an @@ -461,7 +482,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { + findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -515,7 +536,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - object FunctionArgumentConversion extends TypeCoercionRule { + case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -523,7 +544,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -531,7 +552,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -542,7 +563,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -552,7 +573,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -580,7 +601,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -590,14 +611,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -637,11 +658,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CaseWhenCoercion extends TypeCoercionRule { + case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -668,16 +689,17 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - object IfCoercion extends TypeCoercionRule { + case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { + widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -776,12 +798,11 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -804,17 +825,18 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType).map { commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { + commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..0b1965c438e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,7 +107,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") + } + confGetter.get()() + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1274,12 +1280,6 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..4ee12db9c10ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,11 +81,7 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - if (SQLConf.get.caseSensitiveAnalysis) { - DataType.equalsIgnoreNullability(this, other) - } else { - DataType.equalsIgnoreCaseAndNullability(this, other) - } + DataType.equalsIgnoreNullability(this, other) /** * Returns the same data type but set all nullability fields are true @@ -218,7 +214,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..f73e045685ee1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType) => Option[DataType], + widenFunc: (DataType, DataType, Boolean) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2) + var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1) + found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion + val rule = TypeCoercion.FunctionArgumentConversion(conf) val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion + val rule = TypeCoercion.IfCoercion(conf) val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(FunctionArgumentConversion(conf), Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f9a24806953e6..1edf27619ad7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -521,6 +522,8 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) + case (t1, t2) => + TypeCoercion.findWiderTypeForTwo( + t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e0424b7478122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -44,6 +45,7 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord + val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -53,7 +55,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions)) + Some(inferField(parser, configOptions, caseSensitive)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -68,7 +70,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -98,14 +100,15 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + private def inferField( + parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions) + inferField(parser, configOptions, caseSensitive) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -122,7 +125,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions), + inferField(parser, configOptions, caseSensitive), nullable = true) } val fields: Array[StructField] = builder.result() @@ -137,7 +140,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions)) + elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) } ArrayType(elementType) @@ -243,13 +246,14 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode): (DataType, DataType) => DataType = { + parseMode: ParseMode, + caseSensitive: Boolean): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -259,7 +263,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2) + case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -267,8 +271,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType): DataType = { - TypeCoercion.findTightestCommonType(t1, t2).getOrElse { + def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { + TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -303,7 +307,8 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + val dataType = compatibleType( + fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -326,15 +331,17 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + ArrayType( + compatibleType(elementType1, elementType2, caseSensitive), + containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2) + compatibleType(DecimalType.forType(t1), t2, caseSensitive) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2)) + compatibleType(t1, DecimalType.forType(t2), caseSensitive) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..34d23ee53220d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2) + var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1) + actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 75cf369c742e7c7b68f384d123447c97be95c9f0 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 11 May 2018 09:05:35 +0800 Subject: [PATCH 472/593] [SPARK-24197][SPARKR][SQL] Adding array_sort function to SparkR ## What changes were proposed in this pull request? The PR adds array_sort function to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Example ``` > df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) > head(collect(select(df, array_sort(df[[1]])))) ``` Result: ``` array_sort(_1) 1 1, 2, 3, NA 2 4, 5, 6, NA, NA ``` Author: Marek Novotny Closes #21294 from mn-mikke/SPARK-24197. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 21 ++++++++++++++++++--- R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++++---- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 8cd00352d1956..5f8209689a559 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,6 +204,7 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_sort", "asc", "ascii", "asin", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 04d0e4620b28a..1f97054443e1b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -207,7 +207,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -3043,6 +3043,20 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array +#' must be orderable. NA elements will be placed at the end of the returned array. +#' +#' @rdname column_collection_functions +#' @aliases array_sort array_sort,Column-method +#' @note array_sort since 2.4.0 +setMethod("array_sort", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc) + column(jc) + }) + #' @details #' \code{flatten}: Transforms an array of arrays into a single array. #' @@ -3125,8 +3139,9 @@ setMethod("size", }) #' @details -#' \code{sort_array}: Sorts the input array in ascending or descending order according -#' to the natural ordering of the array elements. +#' \code{sort_array}: Sorts the input array in ascending or descending order according to +#' the natural ordering of the array elements. NA elements will be placed at the beginning of +#' the returned array in ascending order or at the end of the returned array in descending order. #' #' @rdname column_collection_functions #' @param asc a logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4ef12d19b3575..5faa51eef3abd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,6 +769,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 43725e0ebd3bf..b8bfded0ebf2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,8 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position(), element_at() - # and sort_array() + # Test array_contains(), array_max(), array_min(), array_position() and element_at() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1497,10 +1496,16 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + # Test array_sort() and sort_array() + df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) + + result <- collect(select(df, array_sort(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA))) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] - expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA))) result <- collect(select(df, sort_array(df[[1]])))[[1]] - expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) # Test flattern df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), From 54032682b910dc5089af27d2c7b6efe55700f034 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 11 May 2018 17:40:35 +0800 Subject: [PATCH 473/593] [SPARK-24182][YARN] Improve error message when client AM fails. Instead of always throwing a generic exception when the AM fails, print a generic error and throw the exception with the YARN diagnostics containing the reason for the failure. There was an issue with YARN sometimes providing a generic diagnostic message, even though the AM provides a failure reason when unregistering. That was happening because the AM was registering too late, and if errors happened before the registration, YARN would just create a generic "ExitCodeException" which wasn't very helpful. Since most errors in this path are a result of not being able to connect to the driver, this change modifies the AM registration a bit so that the AM is registered before the connection to the driver is established. That way, errors are properly propagated through YARN back to the driver. As part of that, I also removed the code that retried connections to the driver from the client AM. At that point, the driver should already be up and waiting for connections, so it's unlikely that retrying would help - and in case it does, that means a flaky network, which would mean problems would probably show up again. The effect of that is that connection-related errors are reported back to the driver much faster now (through the YARN report). One thing to note is that there seems to be a race on the YARN side that causes a report to be sent to the client without the corresponding diagnostics string from the AM; the diagnostics are available later from the RM web page. For that reason, the generic error messages are kept in the Spark scheduler code, to help guide users to a way of debugging their failure. Also of note is that if YARN's max attempts configuration is lower than Spark's, Spark will not unregister the AM with a proper diagnostics message. Unfortunately there seems to be no way to unregister the AM and still allow further re-attempts to happen. Testing: - existing unit tests - some of our integration tests - hardcoded an invalid driver address in the code and verified the error in the shell. e.g. ``` scala> 18/05/04 15:09:34 ERROR cluster.YarnClientSchedulerBackend: YARN application has exited unexpectedly with state FAILED! Check the YARN application logs for more details. 18/05/04 15:09:34 ERROR cluster.YarnClientSchedulerBackend: Diagnostics message: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: Caused by: java.io.IOException: Failed to connect to localhost/127.0.0.1:1234 ``` Author: Marcelo Vanzin Closes #21243 from vanzin/SPARK-24182. --- docs/running-on-yarn.md | 5 +- .../spark/deploy/yarn/ApplicationMaster.scala | 103 +++++++----------- .../org/apache/spark/deploy/yarn/Client.scala | 43 +++++--- .../spark/deploy/yarn/YarnRMClient.scala | 29 +++-- .../cluster/YarnClientSchedulerBackend.scala | 35 ++++-- 5 files changed, 112 insertions(+), 103 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ceda8a3ae2403..c9e68c3bfd056 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -133,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 595077e7e809f..3d6ee50b070a3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -346,7 +346,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends synchronized { if (!finished) { val inShutdown = ShutdownHookManager.inShutdown() - if (registered) { + if (registered || !isClusterMode) { exitCode = code finalStatus = status } else { @@ -389,37 +389,40 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def registerAM( + host: String, + port: Int, _sparkConf: SparkConf, - _rpcEnv: RpcEnv, - driverRef: RpcEndpointRef, - uiAddress: Option[String]) = { + uiAddress: Option[String]): Unit = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = ApplicationMaster .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) - val driverUrl = RpcEndpointAddress( - _sparkConf.get("spark.driver.host"), - _sparkConf.get("spark.driver.port").toInt, + client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress) + registered = true + } + + private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = { + val appId = client.getAttemptId().getApplicationId().toString() + val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString // Before we initialize the allocator, let's log the information about how executors will // be run up front, to avoid printing this out for every single executor being launched. // Use placeholders for information that changes such as executor IDs. logInfo { - val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt - val executorCores = sparkConf.get(EXECUTOR_CORES) - val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "", + val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt + val executorCores = _sparkConf.get(EXECUTOR_CORES) + val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "", "", executorMemory, executorCores, appId, securityMgr, localResources) dummyRunner.launchContextDebugInfo() } - allocator = client.register(driverUrl, - driverRef, + allocator = client.createAllocator( yarnConf, _sparkConf, - uiAddress, - historyAddress, + driverUrl, + driverRef, securityMgr, localResources) @@ -434,15 +437,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends reporterThread = launchReporterThread() } - /** - * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend. - */ - private def createSchedulerRef(host: String, port: String): RpcEndpointRef = { - rpcEnv.setupEndpointRef( - RpcAddress(host, port.toInt), - YarnSchedulerBackend.ENDPOINT_NAME) - } - private def runDriver(): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -456,11 +450,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Duration(totalWaitTime, TimeUnit.MILLISECONDS)) if (sc != null) { rpcEnv = sc.env.rpcEnv - val driverRef = createSchedulerRef( - sc.getConf.get("spark.driver.host"), - sc.getConf.get("spark.driver.port")) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl)) - registered = true + + val userConf = sc.getConf + val host = userConf.get("spark.driver.host") + val port = userConf.get("spark.driver.port").toInt + registerAM(host, port, userConf, sc.ui.map(_.webUrl)) + + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(host, port), + YarnSchedulerBackend.ENDPOINT_NAME) + createAllocator(driverRef, userConf) } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -486,10 +485,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val amCores = sparkConf.get(AM_CORES) rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr, amCores, true) - val driverRef = waitForSparkDriver() + + // The client-mode AM doesn't listen for incoming connections, so report an invalid port. + registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress")) + + // The driver should be up and listening, so unlike cluster mode, just try to connect to it + // with no waiting or retrying. + val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0)) + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(driverHost, driverPort), + YarnSchedulerBackend.ENDPOINT_NAME) addAmIpFilter(Some(driverRef)) - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress")) - registered = true + createAllocator(driverRef, sparkConf) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -600,40 +607,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def waitForSparkDriver(): RpcEndpointRef = { - logInfo("Waiting for Spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - - // Spark driver should already be up since it launched us, but we don't want to - // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME) - val deadline = System.currentTimeMillis + totalWaitTimeMs - - while (!driverUp && !finished && System.currentTimeMillis < deadline) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at %s:%s, retrying ...". - format(driverHost, driverPort)) - Thread.sleep(100L) - } - } - - if (!driverUp) { - throw new SparkException("Failed to connect to driver!") - } - - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - createSchedulerRef(driverHost, driverPort.toString) - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter(driver: Option[RpcEndpointRef]) = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 134b3e5fef11a..7225ff03dc34e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1019,8 +1019,7 @@ private[spark] class Client( appId: ApplicationId, returnOnRunning: Boolean = false, logApplicationReport: Boolean = true, - interval: Long = sparkConf.get(REPORT_INTERVAL)): - (YarnApplicationState, FinalApplicationStatus) = { + interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = { var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) @@ -1031,11 +1030,13 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") cleanupStagingDir(appId) - return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None) case NonFatal(e) => - logError(s"Failed to contact YARN for application $appId.", e) + val msg = s"Failed to contact YARN for application $appId." + logError(msg, e) // Don't necessarily clean up staging dir because status is unknown - return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) + return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED, + Some(msg)) } val state = report.getYarnApplicationState @@ -1073,14 +1074,14 @@ private[spark] class Client( } if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { cleanupStagingDir(appId) - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } if (returnOnRunning && state == YarnApplicationState.RUNNING) { - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } lastState = state @@ -1129,16 +1130,17 @@ private[spark] class Client( throw new SparkException(s"Application $appId finished with status: $state") } } else { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { + val YarnAppReport(appState, finalState, diags) = monitorApplication(appId) + if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) { + diags.foreach { err => + logError(s"Application diagnostics message: $err") + } throw new SparkException(s"Application $appId finished with failed status") } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { + if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) { throw new SparkException(s"Application $appId is killed") } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + if (finalState == FinalApplicationStatus.UNDEFINED) { throw new SparkException(s"The final status of application $appId is undefined") } } @@ -1477,6 +1479,12 @@ private object Client extends Logging { uri.startsWith(s"$LOCAL_SCHEME:") } + def createAppReport(report: ApplicationReport): YarnAppReport = { + val diags = report.getDiagnostics() + val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None + YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) + } + } private[spark] class YarnClusterApplication extends SparkApplication { @@ -1491,3 +1499,8 @@ private[spark] class YarnClusterApplication extends SparkApplication { } } + +private[spark] case class YarnAppReport( + appState: YarnApplicationState, + finalState: FinalApplicationStatus, + diagnostics: Option[String]) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 17234b120ae13..b59dcf158d87c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -42,23 +42,20 @@ private[spark] class YarnRMClient extends Logging { /** * Registers the application master with the RM. * + * @param driverHost Host name where driver is running. + * @param driverPort Port where driver is listening. * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. - * @param securityMgr The security manager. - * @param localResources Map with information about files distributed via YARN's cache. */ def register( - driverUrl: String, - driverRef: RpcEndpointRef, + driverHost: String, + driverPort: Int, conf: YarnConfiguration, sparkConf: SparkConf, uiAddress: Option[String], - uiHistoryAddress: String, - securityMgr: SecurityManager, - localResources: Map[String, LocalResource] - ): YarnAllocator = { + uiHistoryAddress: String): Unit = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) amClient.start() @@ -70,10 +67,19 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, - trackingUrl) + amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl) registered = true } + } + + def createAllocator( + conf: YarnConfiguration, + sparkConf: SparkConf, + driverUrl: String, + driverRef: RpcEndpointRef, + securityMgr: SecurityManager, + localResources: Map[String, LocalResource]): YarnAllocator = { + require(registered, "Must register AM before creating allocator.") new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, localResources, new SparkRackResolver()) } @@ -88,6 +94,9 @@ private[spark] class YarnRMClient extends Logging { if (registered) { amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } + if (amClient != null) { + amClient.stop() + } } /** Returns the attempt ID. */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 06e54a2eaf95a..f1a8df00f9c5b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -75,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend( val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL) assert(client != null && appId.isDefined, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true, - interval = monitorInterval) // blocking + val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get, + returnOnRunning = true, interval = monitorInterval) if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application has already ended! " + - "It might have been killed or unable to launch application master.") + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + val genericMessage = "The YARN application has already ended! " + + "It might have been killed or the Application Master may have failed to start. " + + "Check the YARN application logs for more details." + val exceptionMsg = diags match { + case Some(msg) => + logError(genericMessage) + msg + + case None => + genericMessage + } + throw new SparkException(exceptionMsg) } if (state == YarnApplicationState.RUNNING) { logInfo(s"Application ${appId.get} has started running.") @@ -100,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") + val YarnAppReport(_, state, diags) = + client.monitorApplication(appId.get, logApplicationReport = true) + logError(s"YARN application has exited unexpectedly with state $state! " + + "Check the YARN application logs for more details.") + diags.foreach { err => + logError(s"Diagnostics message: $err") + } allowInterrupt = false sc.stop() } catch { @@ -124,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread - t.setName("Yarn application state monitor") + t.setName("YARN application state monitor") t.setDaemon(true) t } From 928845a42230a2c0a318011002a54ad871468b2e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 10:00:28 -0700 Subject: [PATCH 474/593] [SPARK-24172][SQL] we should not apply operator pushdown to data source v2 many times ## What changes were proposed in this pull request? In `PushDownOperatorsToDataSource`, we use `transformUp` to match `PhysicalOperation` and apply pushdown. This is problematic if we have multiple `Filter` and `Project` above the data source v2 relation. e.g. for a query ``` Project Filter DataSourceV2Relation ``` The pattern match will be triggered twice and we will do operator pushdown twice. This is unnecessary, we can use `mapChildren` to only apply pushdown once. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21230 from cloud-fan/step2. --- .../v2/PushDownOperatorsToDataSource.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 9293d4f831bff..e894f8afd6762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -23,17 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project import org.apache.spark.sql.catalyst.rules.Rule object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply( - plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => - // merge the filters - val filters = relation.filters match { - case Some(existing) => - existing ++ newFilters - case _ => - newFilters - } + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + assert(relation.filters.isEmpty, "data source v2 should do push down only once.") val projectAttrs = project.map(_.toAttribute) val projectSet = AttributeSet(project.flatMap(_.references)) @@ -67,5 +60,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { } else { filtered } + + case other => other.mapChildren(apply) } } From 92f6f52ff0ce47e046656ca8bed7d7bfbbb42dcb Mon Sep 17 00:00:00 2001 From: aditkumar Date: Fri, 11 May 2018 14:42:23 -0500 Subject: [PATCH 475/593] [MINOR][DOCS] Documenting months_between direction ## What changes were proposed in this pull request? It's useful to know what relationship between date1 and date2 results in a positive number. Author: aditkumar Author: Adit Kumar Closes #20787 from aditkumar/master. --- R/pkg/R/functions.R | 6 +++++- python/pyspark/sql/functions.py | 7 +++++-- .../catalyst/expressions/datetimeExpressions.scala | 14 +++++++++++--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 8 ++++---- .../scala/org/apache/spark/sql/functions.scala | 7 ++++++- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 1f97054443e1b..4964594284aa0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1912,6 +1912,7 @@ setMethod("atan2", signature(y = "Column"), #' @details #' \code{datediff}: Returns the number of days from \code{y} to \code{x}. +#' If \code{y} is later than \code{x} then the result is positive. #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method @@ -1971,7 +1972,10 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} +#' are on the same day of month, or both are the last day of month, time of day will be ignored. +#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5a584152b4f6..b62748e9a2d6c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1108,8 +1108,11 @@ def add_months(start, months): @since(1.5) def months_between(date1, date2, roundOff=True): """ - Returns the number of months between date1 and date2. - Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + If date1 and date2 are on the same day of month, or both are the last day of month, + returns an integer (time of day will be ignored). + The result is rounded off to 8 digits unless `roundOff` is set to `False`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 76aa61415a11f..03422fecb3209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1194,13 +1194,21 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } /** - * Returns number of months between dates date1 and date2. + * Returns number of months between times `timestamp1` and `timestamp2`. + * If `timestamp1` is later than `timestamp2`, then the result is positive. + * If `timestamp1` and `timestamp2` are on the same day of month, or both + * are the last day of month, time of day will be ignored. Otherwise, the + * difference is calculated based on 31 days per month, and rounded to + * 8 digits unless roundOff=false. */ // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. - The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", + _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result + is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both + are the last day of month, time of day will be ignored. Otherwise, the difference is + calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false. + """, examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index e646da0659e85..80f15053005ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -885,13 +885,13 @@ object DateTimeUtils { /** * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. + * microseconds since 1.1.1970. If time1 is later than time2, the result is positive. * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). + * If time1 and time2 are on the same day of month, or both are the last day of month, + * returns, time of day will be ignored. * * Otherwise, the difference is calculated based on 31 days per month. - * If `roundOff` is set to true, the result is rounded to 8 decimal places. + * The result is rounded to 8 decimal places if `roundOff` is set to true. */ def monthsBetween( time1: SQLTimestamp, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 225de0051d6fa..e7f866ddca681 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2903,7 +2903,12 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. - * The result is rounded off to 8 digits. + * If `date1` is later than `date2`, then the result is positive. + * If `date1` and `date2` are on the same day of month, or both are the last day of month, + * time of day will be ignored. + * + * Otherwise, the difference is calculated based on 31 days per month, and rounded to + * 8 digits. * @group datetime_funcs * @since 1.5.0 */ From f27a035daf705766d3445e5c6a99867c11c552b0 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 11 May 2018 17:00:51 -0700 Subject: [PATCH 476/593] [SPARKR] Require Java 8 for SparkR This change updates the SystemRequirements and also includes a runtime check if the JVM is being launched by R. The runtime check is done by querying `java -version` ## How was this patch tested? Tested on a Mac and Windows machine Author: Shivaram Venkataraman Closes #21278 from shivaram/sparkr-skip-solaris. --- R/pkg/DESCRIPTION | 1 + R/pkg/R/client.R | 35 +++++++++++++++++++++++++++++++++++ R/pkg/R/sparkR.R | 1 + R/pkg/R/utils.R | 4 ++-- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f16..f52d785e05cdd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html +SystemRequirements: Java (== 8) Depends: R (>= 3.0), methods diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 7244cc9f9e38e..e9295e05872bd 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -60,6 +60,40 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack combinedArgs } +checkJavaVersion <- function() { + javaBin <- "java" + javaHome <- Sys.getenv("JAVA_HOME") + javaReqs <- utils::packageDescription(utils::packageName(), fields=c("SystemRequirements")) + sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) + if (javaHome != "") { + javaBin <- file.path(javaHome, "bin", javaBin) + } + + # If java is missing from PATH, we get an error in Unix and a warning in Windows + javaVersionOut <- tryCatch( + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) + javaVersionFilter <- Filter( + function(x) { + grepl("java version", x) + }, javaVersionOut) + + javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] + # javaVersionStr is of the form 1.8.0_92. + # Extract 8 from it to compare to sparkJavaVersion + javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) + if (javaVersionNum != sparkJavaVersion) { + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) + } +} + launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { @@ -67,6 +101,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } else { sparkSubmitBin <- sparkSubmitBinName } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(launchScript(sparkSubmitBin, combinedArgs)) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 38ee79477996f..d6a2d08f9c218 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,6 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) + checkJavaVersion() launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index f1b5ecaa017df..c3501977e64bc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -746,7 +746,7 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, wait = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) @@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) { # stdout = F means discard output # stdout = "" means to its console (default) # Note that the console of this child process might not be the same as the running R process. - system2(script, combinedArgs, stdout = "", wait = wait) + system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr) } } From e3dabdf6ef210fb9f4337e305feb9c4983a57350 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 12 May 2018 12:15:36 +0800 Subject: [PATCH 477/593] [SPARK-23907] Removes regr_* functions in functions.scala ## What changes were proposed in this pull request? This patch removes the various regr_* functions in functions.scala. They are so uncommon that I don't think they deserve real estate in functions.scala. We can consider adding them later if more users need them. ## How was this patch tested? Removed the associated test case as well. Author: Reynold Xin Closes #21309 from rxin/SPARK-23907. --- .../org/apache/spark/sql/functions.scala | 171 ------------------ .../spark/sql/DataFrameAggregateSuite.scala | 68 ------- 2 files changed, 239 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e7f866ddca681..3c9ace407a58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -811,177 +811,6 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: Column, x: Column): Column = withAggregateFunction { - RegrCount(y.expr, x.expr) - } - - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { - RegrSXX(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: Column, x: Column): Column = withAggregateFunction { - RegrSYY(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgX(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { - RegrSXY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: Column, x: Column): Column = withAggregateFunction { - RegrSlope(y.expr, x.expr) - } - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: Column, x: Column): Column = withAggregateFunction { - RegrR2(y.expr, x.expr) - } - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { - RegrIntercept(y.expr, x.expr) - } - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) - - ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4337fb2290fbc..96c28961e5aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -687,72 +687,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-23907: regression functions") { - val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") - val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) - .toDF("a", "b") - val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( - (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") - checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) - checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) - checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), - Row(null), absTol) - - - checkAggregatesWithTol(correlatedData.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), - absTol) - checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), - absTol) - } } From 5902125ac7ad25a0cb7aa3d98825c8290ee33c12 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sat, 12 May 2018 19:21:42 +0800 Subject: [PATCH 478/593] [SPARK-24198][SPARKR][SQL] Adding slice function to SparkR ## What changes were proposed in this pull request? The PR adds the `slice` function to SparkR. The function returns a subset of consecutive elements from the given array. ``` > df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) > tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) > head(select(tmp, slice(tmp$v1, 2L, 2L))) ``` ``` slice(v1, 2, 2) 1 6, 110 2 6, 110 3 4, 93 4 6, 110 5 8, 175 6 6, 105 ``` ## How was this patch tested? A test added into R/pkg/tests/fulltests/test_sparkSQL.R Author: Marek Novotny Closes #21298 from mn-mikke/SPARK-24198. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 +++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ 4 files changed, 27 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5f8209689a559..c575fe255f57a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -352,6 +352,7 @@ exportMethods("%<=>%", "sinh", "size", "skewness", + "slice", "sort_array", "soundex", "spark_partition_id", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4964594284aa0..77d70cb5d19e6 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -212,6 +212,7 @@ NULL #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, slice(tmp$v1, 2L, 2L))) #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) @@ -3142,6 +3143,22 @@ setMethod("size", column(jc) }) +#' @details +#' \code{slice}: Returns an array containing all the elements in x from the index start +#' (or starting from the end if start is negative) with the specified length. +#' +#' @rdname column_collection_functions +#' @param start an index indicating the first element occuring in the result. +#' @param length a number of consecutive elements choosen to the result. +#' @aliases slice slice,Column-method +#' @note slice since 2.4.0 +setMethod("slice", + signature(x = "Column"), + function(x, start, length) { + jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length) + column(jc) + }) + #' @details #' \code{sort_array}: Sorts the input array in ascending or descending order according to #' the natural ordering of the array elements. NA elements will be placed at the beginning of diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5faa51eef3abd..fbc4113e2becc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1194,6 +1194,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("slice", function(x, start, length) { standardGeneric("slice") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index b8bfded0ebf2d..2a550b9efb506 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1507,6 +1507,11 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) + # Test slice() + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L)))) + result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] + expect_equal(result, list(list(2L, 3L), list(5L))) + # Test flattern df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) From 348ddfd20f5b88777014f18a6374f33ee9b12731 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 May 2018 08:35:14 -0500 Subject: [PATCH 479/593] [BUILD] Close stale PRs Closes https://github.com/apache/spark/pull/20458 Closes https://github.com/apache/spark/pull/20530 Closes https://github.com/apache/spark/pull/20557 Closes https://github.com/apache/spark/pull/20966 Closes https://github.com/apache/spark/pull/20857 Closes https://github.com/apache/spark/pull/19694 Closes https://github.com/apache/spark/pull/18227 Closes https://github.com/apache/spark/pull/20683 Closes https://github.com/apache/spark/pull/20881 Closes https://github.com/apache/spark/pull/20347 Closes https://github.com/apache/spark/pull/20825 Closes https://github.com/apache/spark/pull/20078 Closes https://github.com/apache/spark/pull/21281 Closes https://github.com/apache/spark/pull/19951 Closes https://github.com/apache/spark/pull/20905 Closes https://github.com/apache/spark/pull/20635 Author: Sean Owen Closes #21303 from srowen/ClosePRs. From 32acfa78c60465efc03ae01e022614ad91345b1c Mon Sep 17 00:00:00 2001 From: Cody Allen Date: Sat, 12 May 2018 14:35:40 -0500 Subject: [PATCH 480/593] Improve implicitNotFound message for Encoder The `implicitNotFound` message for `Encoder` doesn't mention the name of the type for which it can't find an encoder. Furthermore, it covers up the fact that `Encoder` is the name of the relevant type class. Hopefully this new message provides a little more specific type detail while still giving the general message about which types are supported. ## What changes were proposed in this pull request? Augment the existing message to mention that it's looking for an `Encoder` and what the type of the encoder is. For example instead of: ``` Unable to find encoder for type stored in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. ``` return this message: ``` Unable to find encoder for type Exception. An implicit Encoder[Exception] is needed to store Exception instances in a Dataset. Primitive types (Int, String, etc) and Product types (ca se classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. ``` ## How was this patch tested? It was tested manually in the Scala REPL, since triggering this in a test would cause a compilation error. ``` scala> implicitly[Encoder[Exception]] :51: error: Unable to find encoder for type Exception. An implicit Encoder[Exception] is needed to store Exception instances in a Dataset. Primitive types (Int, String, etc) and Product types (ca se classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. implicitly[Encoder[Exception]] ^ ``` Author: Cody Allen Closes #20869 from ceedubs/encoder-implicit-msg. --- .../src/main/scala/org/apache/spark/sql/Encoder.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ccdb6bc5d4b7c..7b02317b8538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -68,10 +68,10 @@ import org.apache.spark.sql.types._ */ @Experimental @InterfaceStability.Evolving -@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + - "(Int, String, etc) and Product types (case classes) are supported by importing " + - "spark.implicits._ Support for serializing other types will be added in future " + - "releases.") +@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + + "classes) are supported by importing spark.implicits._ Support for serializing other types " + + "will be added in future releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ From 0d210ec8b610e4b0570ce730f3987dc86787c663 Mon Sep 17 00:00:00 2001 From: Kelley Robinson Date: Sun, 13 May 2018 13:19:03 -0700 Subject: [PATCH 481/593] [SPARK-24262][PYTHON] Fix typo in UDF type match error message ## What changes were proposed in this pull request? Updates `functon` to `function`. This was called out in holdenk's PyCon 2018 conference talk. Didn't see any existing PR's for this. holdenk happy to fix the Pandas.Series bug too but will need a bit more guidance. Author: Kelley Robinson Closes #21304 from robinske/master. --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8bb63fcc7ff9c..5d2e58bef6466 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -82,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " + raise TypeError("Return type of the user-defined function should be " "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " From 2fa33649d96394ae630a092a9f7e1261d1893f6e Mon Sep 17 00:00:00 2001 From: Fan Donglai Date: Sun, 13 May 2018 18:10:00 -0500 Subject: [PATCH 482/593] Update StreamingKMeans.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? I think the ‘n_t+t’ in the following code may be wrong, it shoud be ‘n_t+1’ that means is the number of points to the cluster after it finish the no.t+1 min-batch. *
    * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ * n_t+t &= n_t * a + m_t * \end{align} * $$ *
    Author: Fan Donglai Closes #21179 from ddna1021/master. --- .../org/apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3ca75e8cdb97a..7a5e520d5818e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ - * n_t+t &= n_t * a + m_t + * n_t+1 &= n_t * a + m_t * \end{align} * $$ * From 3f0e801c11e600ed28491924e550d3ba93f19c19 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 14 May 2018 09:48:54 +0800 Subject: [PATCH 483/593] [SPARK-24186][R][SQL] change reverse and concat to collection functions in R ## What changes were proposed in this pull request? reverse and concat are already in functions.R as column string functions. Since now these two functions are categorized as collection functions in scala and python, we will do the same in R. ## How was this patch tested? Add test in test_sparkSQL.R Author: Huaxin Gao Closes #21307 from huaxingao/spark_24186. --- R/pkg/R/functions.R | 35 ++++++++++++++------------- R/pkg/R/generics.R | 4 +-- R/pkg/tests/fulltests/test_sparkSQL.R | 17 +++++++++++-- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 77d70cb5d19e6..fcb3521f901ea 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,7 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -218,7 +218,10 @@ NULL #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) #' head(select(tmp3, map_values(tmp3$v3))) -#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} +#' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL #' Window functions for Column operations @@ -1260,9 +1263,9 @@ setMethod("quarter", }) #' @details -#' \code{reverse}: Reverses the string column and returns it as a new string column. +#' \code{reverse}: Returns a reversed string or an array with reverse order of elements. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases reverse reverse,Column-method #' @note reverse since 1.5.0 setMethod("reverse", @@ -2055,20 +2058,10 @@ setMethod("countDistinct", #' @details #' \code{concat}: Concatenates multiple input columns together into a single column. -#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. +#' The function works with strings, binary and compatible array columns. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases concat concat,Column-method -#' @examples -#' -#' \dontrun{ -#' # concatenate strings -#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), -#' s2 = concat(df$Class, df$Sex, df$Age), -#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), -#' s4 = concat_ws("_", df$Class, df$Sex), -#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) -#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2409,6 +2402,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex), +#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -3063,7 +3063,8 @@ setMethod("array_sort", }) #' @details -#' \code{flatten}: Transforms an array of arrays into a single array. +#' \code{flatten}: Creates a single array from an array of arrays. +#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. #' #' @rdname column_collection_functions #' @aliases flatten flatten,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fbc4113e2becc..61da30badac4e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -817,7 +817,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -1134,7 +1134,7 @@ setGeneric("regexp_replace", #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 2a550b9efb506..13b55ac6e6e3c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,7 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position() and element_at() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1496,6 +1496,13 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + result <- collect(select(df, reverse(df[[1]])))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L))) + + df2 <- createDataFrame(list(list("abc"))) + result <- collect(select(df2, reverse(df2[[1]])))[[1]] + expect_equal(result, "cba") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1512,7 +1519,13 @@ test_that("column functions", { result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] expect_equal(result, list(list(2L, 3L), list(5L))) - # Test flattern + # Test concat() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + list(list(7L, 8L, 9L), list(10L, 11L, 12L)))) + result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L))) + + # Test flatten() df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) result <- collect(select(df, flatten(df[[1]])))[[1]] From 7a2d4895c75d4c232c377876b61c05a083eab3c8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 14 May 2018 10:01:06 +0800 Subject: [PATCH 484/593] [SPARK-17916][SQL] Fix empty string being parsed as null when nullValue is set. ## What changes were proposed in this pull request? I propose to bump version of uniVocity parser up to 2.6.3 where quoted empty strings are replaced by the empty value (passed to `setEmptyValue`) instead of `null` values as in the current version 2.5.9: https://github.com/uniVocity/univocity-parsers/blob/v2.6.3/src/main/java/com/univocity/parsers/csv/CsvParser.java#L125 Empty value for writer is set to `""`. So, empty string in dataframe/dataset is stored as empty quoted string `""`. Empty value for reader is set to empty string (zero size). In this way, saved empty quoted string will be read as just empty string. Please, look at the tests for more details. Here are main changes made in [2.6.0](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.0), [2.6.1](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.1), [2.6.2](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.2), [2.6.3](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.3): - CSV parser now parses quoted values ~30% faster - CSV format detection process has option provide a list of possible delimiters, in order of priority ( i.e. settings.detectFormatAutomatically( '-', '.');) - https://github.com/uniVocity/univocity-parsers/issues/214 - Implemented trim quoted values support - https://github.com/uniVocity/univocity-parsers/issues/230 - NullPointer when stopping parser when nothing is parsed - https://github.com/uniVocity/univocity-parsers/issues/219 - Concurrency issue when calling stopParsing() - https://github.com/uniVocity/univocity-parsers/issues/231 Closes #20068 ## How was this patch tested? Added tests from the PR https://github.com/apache/spark/pull/20068 Author: Maxim Gekk Closes #21273 from MaxGekk/univocity-2.6. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- sql/core/pom.xml | 2 +- .../datasources/csv/CSVOptions.scala | 3 +- .../datasources/csv/CSVBenchmarks.scala | 80 +++++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 46 +++++++++++ 7 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f552b81fde9f4..e710e26348117 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -190,7 +190,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 024b1fca717df..97ad17a9ff7b1 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -191,7 +191,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 938de7bc06663..e21bfef8c4291 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -211,7 +211,7 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar token-provider-1.0.1.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar xbean-asm5-shaded-4.4.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ef41837f89d68..f270c70fbfcf0 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.9 + 2.6.3 jar diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index ed2dc65a47914..1066d156acd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -164,7 +164,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("\"\"") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -185,6 +185,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) + settings.setEmptyValue("") settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala new file mode 100644 index 0000000000000..d442ba7e59c61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure CSV read/write performance. + * To run this: + * spark-submit --class --jars + */ +object CSVBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-csv-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + + withTempPath { path => + val str = (0 until 10000).map(i => s""""$i"""").mkString(",") + + spark.range(rowsNum) + .map(_ => str) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val schema = new StructType().add("value", StringType) + val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"One quoted string", numIters) { _ => + ds.filter((_: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + One quoted string 30273 / 30549 0.0 605451.2 1.0X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 461abdd96d3f3..07e6c74b14d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1322,4 +1322,50 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where an empty string is not coerced to null when `nullValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("nullValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("nullValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, ""), + (3, litNull), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to nullValue is not passed. + withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } } From b6c50d7820aafab172835633fb0b35899e93146b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 14 May 2018 10:57:10 +0800 Subject: [PATCH 485/593] [SPARK-24228][SQL] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following Java lint errors due to importing unimport classes ``` $ dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java:[25] (sizes) LineLength: Line is longer than 100 characters (found 109). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java:[38] (sizes) LineLength: Line is longer than 100 characters (found 102). [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[21,8] (imports) UnusedImports: Unused import - java.io.ByteArrayInputStream. [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java:[29,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java:[110] (sizes) LineLength: Line is longer than 100 characters (found 101). ``` With this PR ``` $ dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks passed. ``` ## How was this patch tested? Existing UTs. Also manually run checkstyles against these two files. Author: Kazuaki Ishizaki Closes #21301 from kiszk/SPARK-24228. --- .../datasources/parquet/SpecificParquetRecordReaderBase.java | 1 - .../datasources/parquet/VectorizedPlainValuesReader.java | 1 - .../sql/sources/v2/reader/partitioning/Distribution.java | 3 ++- .../sql/sources/v2/reader/streaming/ContinuousReader.java | 4 ++-- .../apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java | 3 ++- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 10d6ed85a4080..daedfd7e78f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index aacefacfc1c1a..c62dc3d86386e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -26,7 +26,6 @@ import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; -import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index d2ee9518d628f..5e32ba6952e1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -22,7 +22,8 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). + * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * partition). * Note that this interface has nothing to do with the data ordering inside one * partition(the output records of a single {@link InputPartitionReader}). * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 716c5c0e9e15a..6e960bedf8020 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -35,8 +35,8 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each - * partition to a single global offset. + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances + * for each partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 714638e500c94..445cb29f5ee3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -107,7 +107,8 @@ public List> planInputPartitions() { } } - static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { private int start; private int end; private StructType requiredSchema; From 1430fa80e37762e31cc5adc74cd609c215d84b6e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 14 May 2018 10:49:12 -0700 Subject: [PATCH 486/593] [SPARK-24263][R] SparkR java check breaks with openjdk ## What changes were proposed in this pull request? Change text to grep for. ## How was this patch tested? manual test Author: Felix Cheung Closes #21314 from felixcheung/openjdkver. --- R/pkg/R/client.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index e9295e05872bd..14a17c600b17f 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -82,7 +82,7 @@ checkJavaVersion <- function() { }) javaVersionFilter <- Filter( function(x) { - grepl("java version", x) + grepl(" version", x) }, javaVersionOut) javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] From c26f673252c2cbbccf8c395ba6d4ab80c098d60e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 14 May 2018 11:37:57 -0700 Subject: [PATCH 487/593] [SPARK-24246][SQL] Improve AnalysisException by setting the cause when it's available ## What changes were proposed in this pull request? If there is an exception, it's better to set it as the cause of AnalysisException since the exception may contain useful debug information. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #21297 from zsxwing/SPARK-24246. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- .../spark/sql/catalyst/analysis/ResolveInlineTables.scala | 2 +- .../org/apache/spark/sql/catalyst/analysis/package.scala | 5 +++++ .../org/apache/spark/sql/execution/datasources/rules.scala | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dfdcdbc1eb2c7..3eaa9ecf5d075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -676,13 +676,13 @@ class Analyzer( try { catalog.lookupRelation(tableIdentWithDb) } catch { - case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + case e: NoSuchTableException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e) // If the database is defined and that database is not found, throw an AnalysisException. // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exist.") + s"database ${e.db} doesn't exist.", e) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 4eb6e642b1c37..31ba9d792024b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -105,7 +105,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas castedExpr.eval() } catch { case NonFatal(ex) => - table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 7731336d247db..354a3fa0602a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -41,6 +41,11 @@ package object analysis { def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } + + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String, cause: Throwable): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause)) + } } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0dea767840ed3..cab00251622b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { case _: ClassNotFoundException => u case e: Exception => // the provider is valid, but failed to create a logical plan - u.failAnalysis(e.getMessage) + u.failAnalysis(e.getMessage, e) } } } From 075d678c8844614910b50abca07282bde31ef7e0 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 14 May 2018 13:35:54 -0700 Subject: [PATCH 488/593] [SPARK-24155][ML] Instrumentation improvements for clustering ## What changes were proposed in this pull request? changed the instrument for all of the clustering methods ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21218 from ludatabricks/SPARK-23686-1. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 7 +++++-- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 5 ++++- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 +++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 438e53ba6197c..1ad4e097246a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -261,8 +261,9 @@ class BisectingKMeans @Since("2.0.0") ( transformSchema(dataset.schema, logging = true) val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, rdd) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val instr = Instrumentation.create(this, dataset) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() .setK($(k)) @@ -275,6 +276,8 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88d618c3a03a8..3091bb5a2e54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -352,7 +352,7 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) @@ -425,6 +425,9 @@ class GaussianMixture @Since("2.0.0") ( val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) + instr.logNamedValue("logLikelihood", logLikelihood) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 97f246fbfd859..e72d7f9485e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -342,7 +342,7 @@ class KMeans @Since("1.5.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -359,6 +359,8 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() From 8cd83acf4075d369bfcf9e703760d4946ef15f00 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 14 May 2018 14:05:42 -0700 Subject: [PATCH 489/593] [SPARK-24027][SQL] Support MapType with StringType for keys as the root type by from_json ## What changes were proposed in this pull request? Currently, the from_json function support StructType or ArrayType as the root type. The PR allows to specify MapType(StringType, DataType) as the root type additionally to mentioned types. For example: ```scala import org.apache.spark.sql.types._ val schema = MapType(StringType, IntegerType) val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() in.select(from_json($"value", schema, Map[String, String]())).collect() ``` ``` res1: Array[org.apache.spark.sql.Row] = Array([Map(a -> 1, b -> 2, c -> 3)]) ``` ## How was this patch tested? It was checked by new tests for the map type with integer type and struct type as value types. Also roundtrip tests like from_json(to_json) and to_json(from_json) for MapType are added. Author: Maxim Gekk Author: Maxim Gekk Closes #21108 from MaxGekk/from_json-map-type. --- python/pyspark/sql/functions.py | 10 ++- .../expressions/jsonExpressions.scala | 10 ++- .../sql/catalyst/json/JacksonParser.scala | 18 ++++- .../org/apache/spark/sql/functions.scala | 29 ++++---- .../apache/spark/sql/JsonFunctionsSuite.scala | 66 +++++++++++++++++++ 5 files changed, 113 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b62748e9a2d6c..6866c1cf9f882 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2095,12 +2095,13 @@ def json_tuple(col, *fields): return Column(jc) +@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` - of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an - unparseable string. + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -2117,6 +2118,9 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] + >>> schema = MapType(StringType(), IntegerType()) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 34161f0f03f4a..04a4eb0ffc032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -548,7 +548,7 @@ case class JsonToStructs( forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) => + case _: StructType | ArrayType(_: StructType, _) | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") @@ -558,6 +558,7 @@ case class JsonToStructs( lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st + case mt: MapType => mt } // This converts parsed rows to the desired output by the given schema. @@ -567,6 +568,8 @@ case class JsonToStructs( (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: MapType => + (rows: Seq[InternalRow]) => rows.head.getMap(0) } @transient @@ -613,6 +616,11 @@ case class JsonToStructs( } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def sql: String = schema match { + case _: MapType => "entries" + case _ => super.sql + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a5a4a13eb608b..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( - schema: StructType, + schema: DataType, val options: JSONOptions) extends Logging { import JacksonUtils._ @@ -57,7 +57,14 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { + private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) val fieldConverters = st.map(_.dataType).map(makeConverter).toArray (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { @@ -87,6 +94,13 @@ class JacksonParser( } } + private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) { + case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3c9ace407a58e..b71dfdad8aa9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3231,9 +3231,9 @@ object functions { from_json(e, schema.asInstanceOf[DataType], options) /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3263,9 +3263,9 @@ object functions { from_json(e, schema, options.asScala.toMap) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3292,8 +3292,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, + * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3305,9 +3306,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3322,9 +3323,9 @@ object functions { } /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string, it could be a diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 00d2acc4a1d8a..055e1fc5640f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -326,4 +326,70 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg4.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() + val schema = + """ + |{ + | "type" : "map", + | "keyType" : "string", + | "valueType" : "integer", + | "valueContainsNull" : true + |} + """.stripMargin + val out = in.select(from_json($"value", schema, Map[String, String]())) + + assert(out.columns.head == "entries") + checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3))) + } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, new StructType().add("b", IntegerType), true) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Row(1)))) + } + + test("SPARK-24027: from_json - map>") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, MapType(StringType, IntegerType)) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) + } + + test("SPARK-24027: roundtrip - from_json -> to_json - map") { + val json = """{"a":1,"b":2,"c":3}""" + val schema = MapType(StringType, IntegerType, true) + val out = Seq(json).toDS().select(to_json(from_json($"value", schema))) + + checkAnswer(out, Row(json)) + } + + test("SPARK-24027: roundtrip - to_json -> from_json - map") { + val in = Seq(Map("a" -> 1)).toDF() + val schema = MapType(StringType, IntegerType, true) + val out = in.select(from_json(to_json($"value"), schema)) + + checkAnswer(out, in) + } + + test("SPARK-24027: from_json - wrong map") { + val in = Seq("""{"a" 1}""").toDS() + val schema = MapType(StringType, IntegerType) + val out = in.select(from_json($"value", schema, Map[String, String]())) + + checkAnswer(out, Row(null)) + } + + test("SPARK-24027: from_json of a map with unsupported key type") { + val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) + + checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + } } From 061e0084ce19c1384ba271a97a0aa1f87abe879d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 14 May 2018 14:35:08 -0700 Subject: [PATCH 490/593] [SPARK-23852][SQL] Add withSQLConf(...) to test case ## What changes were proposed in this pull request? Add a `withSQLConf(...)` wrapper to force Parquet filter pushdown for a test that relies on it. ## How was this patch tested? Test passes Author: Henry Robinson Closes #21323 from henryr/spark-23582. --- .../datasources/parquet/ParquetFilterSuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4d0ecdef60986..90da7eb8c4fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -650,13 +650,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-23852: Broken Parquet push-down for partially-written stats") { - // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. - // The row-group statistics include null counts, but not min and max values, which - // triggers PARQUET-1217. - val df = readResourceParquetFile("test-data/parquet-1217.parquet") + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. + // The row-group statistics include null counts, but not min and max values, which + // triggers PARQUET-1217. + val df = readResourceParquetFile("test-data/parquet-1217.parquet") - // Will return 0 rows if PARQUET-1217 is not fixed. - assert(df.where("col > 0").count() === 2) + // Will return 0 rows if PARQUET-1217 is not fixed. + assert(df.where("col > 0").count() === 2) + } } } From 9059f1ee6ae13c8636c9b7fdbb708a349256fb8e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 14 May 2018 19:20:25 -0700 Subject: [PATCH 491/593] [SPARK-23780][R] Failed to use googleVis library with new SparkR ## What changes were proposed in this pull request? change generic to get it to work with googleVis also fix lintr ## How was this patch tested? manual test, unit tests Author: Felix Cheung Closes #21315 from felixcheung/googvis. --- R/pkg/R/client.R | 5 +++-- R/pkg/R/generics.R | 2 +- R/pkg/R/sparkR.R | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 14a17c600b17f..4c87f64e7f0e1 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -63,7 +63,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack checkJavaVersion <- function() { javaBin <- "java" javaHome <- Sys.getenv("JAVA_HOME") - javaReqs <- utils::packageDescription(utils::packageName(), fields=c("SystemRequirements")) + javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements")) sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) if (javaHome != "") { javaBin <- file.path(javaHome, "bin", javaBin) @@ -90,7 +90,8 @@ checkJavaVersion <- function() { # Extract 8 from it to compare to sparkJavaVersion javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) if (javaVersionNum != sparkJavaVersion) { - stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", + javaVersionStr)) } } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 61da30badac4e..3ea181157b644 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -624,7 +624,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) +setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d6a2d08f9c218..f7c1663d32c96 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -194,7 +194,7 @@ sparkR.sparkContext <- function( # Don't use readString() so that we can provide a useful # error message if the R and Java versions are mismatched. - authSecretLen = readInt(f) + authSecretLen <- readInt(f) if (length(authSecretLen) == 0 || authSecretLen == 0) { stop("Unexpected EOF in JVM connection data. Mismatched versions?") } From e29176fd7dbcef04a29c4922ba655d58144fed24 Mon Sep 17 00:00:00 2001 From: Goun Na Date: Tue, 15 May 2018 14:11:20 +0800 Subject: [PATCH 492/593] [SPARK-23627][SQL] Provide isEmpty in Dataset ## What changes were proposed in this pull request? This PR adds isEmpty() in DataSet ## How was this patch tested? Unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Goun Na Author: goungoun Closes #20800 from goungoun/SPARK-23627. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 ++++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d518e07bfb62c..f001f16e1d5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -511,6 +511,16 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if the `Dataset` is empty. + * + * @group basic + * @since 2.4.0 + */ + def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) == 0 + } + /** * Returns true if this Dataset contains one or more sources that continuously * return data as it arrives. A Dataset that reads data from a streaming source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e0f4d2ba685e1..d477d78dc14e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23627: provide isEmpty in DataSet") { + val ds1 = spark.emptyDataset[Int] + val ds2 = Seq(1, 2, 3).toDS() + + assert(ds1.isEmpty == true) + assert(ds2.isEmpty == false) + } + test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] From 80c6d35a3edbfb2e053c7d6650e2f725c36af53e Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 14 May 2018 23:34:42 -0700 Subject: [PATCH 493/593] [SPARK-24035][SQL] SQL syntax for Pivot - fix antlr warning ## What changes were proposed in this pull request? 1. Change antlr rule to fix the warning. 2. Add PIVOT/LATERAL check in AstBuilder with a more meaningful error message. ## How was this patch tested? 1. Add a counter case in `PlanParserSuite.test("lateral view")` Author: maryannxue Closes #21324 from maryannxue/spark-24035-fix. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 +++ .../spark/sql/catalyst/parser/PlanParserSuite.scala | 10 ++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f7f921ec22c35..7c54851097af3 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* (pivotClause | lateralView*)? + : FROM relation (',' relation)* lateralView* pivotClause? ; aggregation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 64eed23884584..b9ece295c2510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -504,6 +504,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging withJoinRelations(join, relation) } if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } withPivot(ctx.pivotClause, from) } else { ctx.lateralView.asScala.foldLeft(from)(withGenerate) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 812bfdd7bb885..fb51376c6163f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", expected) + + intercept( + """select * + |from t + |lateral view explode(x) expl + |pivot ( + | sum(x) + | FOR y IN ('a', 'b') + |)""".stripMargin, + "LATERAL cannot be used together with PIVOT in FROM clause") } test("joins") { From 4a2b15f0af400c71b7f20b2048f38a8b74d43dfa Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 15 May 2018 16:04:17 +0800 Subject: [PATCH 494/593] [SPARK-24241][SUBMIT] Do not fail fast when dynamic resource allocation enabled with 0 executor ## What changes were proposed in this pull request? ``` ~/spark-2.3.0-bin-hadoop2.7$ bin/spark-sql --num-executors 0 --conf spark.dynamicAllocation.enabled=true Java HotSpot(TM) 64-Bit Server VM warning: ignoring option PermSize=1024m; support was removed in 8.0 Java HotSpot(TM) 64-Bit Server VM warning: ignoring option MaxPermSize=1024m; support was removed in 8.0 Error: Number of executors must be a positive number Run with --help for usage help or --verbose for debug output ``` Actually, we could start up with min executor number with 0 before if dynamically ## How was this patch tested? ut added Author: Kent Yao Closes #21290 from yaooqinn/SPARK-24241. --- .../spark/deploy/SparkSubmitArguments.scala | 7 +++++-- .../spark/deploy/SparkSubmitSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0733fdb72cafb..fed4e0a5069c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils - /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var proxyUser: String = null var principal: String = null var keytab: String = null + private var dynamicAllocationEnabled: Boolean = false // Standalone cluster mode only var supervise: Boolean = false @@ -198,6 +198,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + dynamicAllocationEnabled = + sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -274,7 +276,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { error("Total executor cores must be a positive number") } - if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + if (!dynamicAllocationEnabled && + numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7451e07b25a1f..43286953e4383 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -180,6 +180,26 @@ class SparkSubmitSuite appArgs.toString should include ("thequeue") } + test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") { + val clArgs1 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=true", + "thejar.jar") + new SparkSubmitArguments(clArgs1) + + val clArgs2 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=false", + "thejar.jar") + + val e = intercept[SparkException](new SparkSubmitArguments(clArgs2)) + assert(e.getMessage.contains("Number of executors must be a positive number")) + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", From d610d2a3f57ca551f72cb4e5dfed78f27be62eec Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 May 2018 22:06:58 +0800 Subject: [PATCH 495/593] [SPARK-24259][SQL] ArrayWriter for Arrow produces wrong output ## What changes were proposed in this pull request? Right now `ArrayWriter` used to output Arrow data for array type, doesn't do `clear` or `reset` after each batch. It produces wrong output. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21312 from viirya/SPARK-24259. --- python/pyspark/sql/tests.py | 20 +++++++++++++++++++ .../sql/execution/arrow/ArrowWriter.scala | 8 ++++++++ 2 files changed, 28 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa9378ad8ee..a1b6db71782bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4680,6 +4680,26 @@ def test_supported_types(self): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b63513548fe..66888fce7f9f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter { valueVector match { case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case listVector: ListVector => + // Manual "reset" the underlying buffer. + // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call + // `listVector.reset()`. + val buffers = listVector.getBuffers(false) + buffers.foreach(buf => buf.setZero(0, buf.capacity())) + listVector.setValueCount(0) + listVector.setLastSet(0) case _ => } count = 0 From 3fabbc576203c7fd63808a259adafc5c3cea1838 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 15 May 2018 10:25:29 -0700 Subject: [PATCH 496/593] [SPARK-24040][SS] Support single partition aggregates in continuous processing. ## What changes were proposed in this pull request? Support aggregates with exactly 1 partition in continuous processing. A few small tweaks are needed to make this work: * Replace currentEpoch tracking with an ThreadLocal. This means that current epoch is scoped to a task rather than a node, but I think that's sustainable even once we add shuffle. * Add a new testing-only flag to disable the UnsupportedOperationChecker whitelist of allowed continuous processing nodes. I think this is preferable to writing a pile of custom logic to enforce that there is in fact only 1 partition; we plan to support multi-partition aggregates before the next Spark release, so we'd just have to tear that logic back out. * Restart continuous processing queries from the first available uncommitted epoch, rather than one that's guaranteed to be unused. This is required for stateful operators to overwrite partial state from the previous attempt at the epoch, and there was no specific motivation for the original strategy. In another PR before stabilizing the StreamWriter API, we'll need to narrow down and document more precise semantic guarantees for the epoch IDs. * We need a single-partition ContinuousMemoryStream. The way MemoryStream is constructed means it can't be a text option like it is for rate source, unfortunately. ## How was this patch tested? new unit tests Author: Jose Torres Closes #21239 from jose-torres/withAggr. --- .../UnsupportedOperationChecker.scala | 1 + .../continuous/ContinuousExecution.scala | 11 +-- .../ContinuousQueuedDataReader.scala | 7 +- .../continuous/ContinuousWriteRDD.scala | 18 +++-- .../streaming/continuous/EpochTracker.scala | 58 +++++++++++++++ .../sources/ContinuousMemoryStream.scala | 14 ++-- .../streaming/state/StateStoreRDD.scala | 10 ++- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../ContinuousAggregationSuite.scala | 72 +++++++++++++++++++ .../ContinuousQueuedDataReaderSuite.scala | 1 + 10 files changed, 167 insertions(+), 29 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d3d6c636c4ba8..2bed41672fe33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f58146ac42398..0e7d1019b9c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -122,16 +122,7 @@ class ContinuousExecution( s"Batch $latestEpochId was committed without end epoch offsets!") } committedOffsets = nextOffsets.toStreamProgress(sources) - - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index d8645576c2052..f38577b6a9f16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -46,8 +46,6 @@ class ContinuousQueuedDataReader( // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = ContinuousDataSourceRDD.getContinuousReader(reader).getOffset - private var currentEpoch: Long = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong /** * The record types in the read buffer. @@ -115,8 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), currentEpoch, currentOffset)) - currentEpoch += 1 + context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -184,7 +181,7 @@ class ContinuousQueuedDataReader( private val epochCoordEndpoint = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That // field represents the epoch wrt the data being processed. The currentEpoch here is just a // counter to ensure we send the appropriate number of markers if we fall behind the driver. private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 91f1576581511..ef5f0da1e7cc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -45,7 +45,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null @@ -54,19 +55,24 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor try { val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) + context.partitionId(), + context.attemptNumber(), + EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) } logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch is committing.") + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") val msg = dataWriter.commit() epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + CommitPartitionEpoch( + context.partitionId(), + EpochTracker.getCurrentEpoch.get, + msg) ) logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch committed.") - currentEpoch += 1 + s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.") + EpochTracker.incrementCurrentEpoch() } catch { case _: InterruptedException => // Continuous shutdown always involves an interrupt. Just finish the task. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala new file mode 100644 index 0000000000000..bc0ae428d4521 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +/** + * Tracks the current continuous processing epoch within a task. Call + * EpochTracker.getCurrentEpoch to get the current epoch. + */ +object EpochTracker { + // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will + // update the underlying AtomicLong as it finishes epochs. Other code should only read the value. + private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] { + override def initialValue() = new AtomicLong(-1) + } + + /** + * Get the current epoch for the current task, or None if the task has no current epoch. + */ + def getCurrentEpoch: Option[Long] = { + currentEpoch.get().get() match { + case n if n < 0 => None + case e => Some(e) + } + } + + /** + * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * between epochs. + */ + def incrementCurrentEpoch(): Unit = { + currentEpoch.get().incrementAndGet() + } + + /** + * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * at the beginning of a task. + */ + def initializeCurrentEpoch(startEpoch: Long): Unit = { + currentEpoch.get().set(startEpoch) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index fef792eab69d5..4daafa65850de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -47,10 +47,9 @@ import org.apache.spark.util.RpcUtils * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified * offset within the list, or null if that offset doesn't yet have a record. */ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) - private val NUM_PARTITIONS = 2 protected val logicalPlan = StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) @@ -58,7 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) // ContinuousReader implementation @GuardedBy("this") - private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + private val records = Seq.fill(numPartitions)(new ListBuffer[A]) @GuardedBy("this") private var startOffset: ContinuousMemoryStreamOffset = _ @@ -69,17 +68,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = synchronized { // Distribute data evenly among partition lists. data.toSeq.zipWithIndex.map { - case (item, index) => records(index % NUM_PARTITIONS) += item + case (item, index) => records(index % numPartitions) += item } // The new target offset is the offset where all records in all partitions have been processed. - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } override def setStartOffset(start: Optional[Offset]): Unit = synchronized { // Inferred initial offset is position 0 in each partition. startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) }.asInstanceOf[ContinuousMemoryStreamOffset] } @@ -152,6 +151,9 @@ object ContinuousMemoryStream { def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 01d8e75980993..3f11b8f79943c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => storeVersion + case Some(value) => value + } + store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7cefd03e43bc3..97da2b1325f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -242,7 +242,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo (sink, trigger) match { case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala new file mode 100644 index 0000000000000..b7ef637f5270e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode + +class ContinuousAggregationSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("not enabled") { + val ex = intercept[AnalysisException] { + val input = ContinuousMemoryStream.singlePartition[Int] + testStream(input.toDF().agg(max('value)), OutputMode.Complete)() + } + + assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + } + + test("basic") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + } + + test("repeated restart") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + StartStream(), + StopStream, + StartStream(), + StopStream, + StartStream(), + AddData(input, 0), + CheckAnswer(2), + AddData(input, 5), + CheckAnswer(5)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index f47d3ec8ae025..e663fa8312da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -51,6 +51,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { startEpoch, spark, SparkEnv.get) + EpochTracker.initializeCurrentEpoch(0) } override def afterEach(): Unit = { From 6b94420f6c672683678a54404e6341a0b9ab3c24 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 15 May 2018 14:16:31 -0700 Subject: [PATCH 497/593] [SPARK-24231][PYSPARK][ML] Provide Python API for evaluateEachIteration for spark.ml GBTs ## What changes were proposed in this pull request? Add evaluateEachIteration for GBTClassification and GBTRegressionModel ## How was this patch tested? doctest Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21335 from ludatabricks/SPARK-14682. --- python/pyspark/ml/classification.py | 15 +++++++++++++++ python/pyspark/ml/regression.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ec17653a1adf9..424ecfd89b060 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)], + ... ["indexed", "features"]) + >>> model.evaluateEachIteration(validation) + [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] .. versionadded:: 1.4.0 """ @@ -1319,6 +1323,17 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + return self._call_java("evaluateEachIteration", dataset) + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9a66d87d7f211..dd0b62f184d26 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))], + ... ["label", "features"]) + >>> model.evaluateEachIteration(validation, "squared") + [0.0, 0.0, 0.0, 0.0, 0.0] .. versionadded:: 1.4.0 """ @@ -1156,6 +1160,20 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset, loss): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + :param loss: + The loss function used to compute error. + Supported options: squared, absolute + """ + return self._call_java("evaluateEachIteration", dataset, loss) + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, From 8a13c5096898f95d1dfcedaf5d31205a1cbf0a19 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 May 2018 16:50:09 -0700 Subject: [PATCH 498/593] [SPARK-24058][ML][PYSPARK] Default Params in ML should be saved separately: Python API ## What changes were proposed in this pull request? See SPARK-23455 for reference. Now default params in ML are saved separately in metadata file in Scala. We must change it for Python for Spark 2.4.0 as well in order to keep them in sync. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21153 from viirya/SPARK-24058. --- python/pyspark/ml/tests.py | 38 ++++++++++++++++++++++++++++++++++++++ python/pyspark/ml/util.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 093593132e56d..0dde0db9e3339 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1595,6 +1595,44 @@ def test_default_read_write(self): self.assertEqual(lr.uid, lr3.uid) self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + class LDATest(SparkSessionTestCase): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a486c6a3fdeb5..9fa85664939b8 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -30,6 +30,7 @@ from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession +from pyspark.util import VersionUtils def _jvm(): @@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): - sparkVersion - uid - paramMap + - defaultParamMap (since 2.4.0) - (optionally, extra metadata) :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc. :param paramMap: If given, this is saved in the "paramMap" field. @@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): """ uid = instance.uid cls = instance.__module__ + '.' + instance.__class__.__name__ - params = instance.extractParamMap() + + # User-supplied param values + params = instance._paramMap jsonParams = {} if paramMap is not None: jsonParams = paramMap else: for p in params: jsonParams[p.name] = params[p] + + # Default param values + jsonDefaultParams = {} + for p in instance._defaultParamMap: + jsonDefaultParams[p.name] = instance._defaultParamMap[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), - "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, + "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: basicMetadata.update(extraMetadata) return json.dumps(basicMetadata, separators=[',', ':']) @@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata): """ Extract Params from metadata, and set them in the instance. """ + # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) paramValue = metadata['paramMap'][paramName] instance.set(param, paramValue) + # Set default param values + majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) + major = majorAndMinorVersions[0] + minor = majorAndMinorVersions[1] + + # For metadata file prior to Spark 2.4, there is no default section. + if major > 2 or (major == 2 and minor >= 4): + assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \ + "`defaultParamMap` section not found" + + for paramName in metadata['defaultParamMap']: + paramValue = metadata['defaultParamMap'][paramName] + instance._setDefault(**{paramName: paramValue}) + @staticmethod def loadParamsInstance(path, sc): """ From 943493b165185c5362c8350dd355276cc458aad0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 16 May 2018 22:01:24 +0800 Subject: [PATCH 499/593] =?UTF-8?q?Revert=20"[SPARK-22938][SQL][FOLLOWUP]?= =?UTF-8?q?=20Assert=20that=20SQLConf.get=20is=20acces=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …sed only on the driver" This reverts commit a4206d58e05ab9ed6f01fee57e18dee65cbc4efc. This is from https://github.com/apache/spark/pull/21299 and to ease the review of it. Author: Wenchen Fan Closes #21341 from cloud-fan/revert. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++---------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +-- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 ++--- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 140 insertions(+), 188 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 94b0561529e71..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - val widerType = TypeCoercion.findWiderTypeForTwo( - dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) - if (widerType.isEmpty) { + if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 31ba9d792024b..71ed75454cd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( - inputTypes, conf.caseSensitiveAnalysis) - val tpe = wideType.getOrElse { + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a7ba201509b78..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes(conf) :: + WidenSetOperationTypes :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion(conf) :: + FunctionArgumentConversion :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion(conf) :: - IfCoercion(conf) :: + CaseWhenCoercion :: + IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts(conf) :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,10 +83,7 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - def findTightestCommonType( - left: DataType, - right: DataType, - caseSensitive: Boolean): Option[DataType] = (left, right) match { + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -105,32 +102,22 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => - val isSameType = if (caseSensitive) { - DataType.equalsIgnoreNullability(t1, t2) - } else { - DataType.equalsIgnoreCaseAndNullability(t1, t2) - } - - if (isSameType) { - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since t1 is same type of t2, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - } else { - None - } + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2, caseSensitive) - val valueType = findTightestCommonType(vt1, vt2, caseSensitive) + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -185,14 +172,13 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2, caseSensitive) - .map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -207,8 +193,7 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -216,7 +201,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeForTwo(d, c) case _ => None }) } @@ -228,22 +213,20 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType, - caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) + findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } @@ -296,32 +279,29 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes( - children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -336,19 +316,18 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType], - caseSensitive: Boolean): Seq[DataType] = { + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) + getWidestTypes(children, attrIndex + 1, castedTypes) } } @@ -453,7 +432,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) + .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -482,7 +461,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { + findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -536,7 +515,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { + object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -544,7 +523,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -552,7 +531,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -563,7 +542,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -573,7 +552,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -601,7 +580,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -611,14 +590,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -658,11 +637,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { + object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -689,17 +668,16 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { + object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -798,11 +776,12 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -825,18 +804,17 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0b1965c438e27..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") - } - confGetter.get()() - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1280,6 +1274,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4ee12db9c10ca..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true @@ -214,7 +218,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index f73e045685ee1..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType, Boolean) => Option[DataType], + widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) + var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) + found = widenFunc(t2, t1) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion(conf) + val rule = TypeCoercion.FunctionArgumentConversion val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion(conf) + val rule = TypeCoercion.IfCoercion val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion(conf), Division) + val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1edf27619ad7b..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -522,8 +521,6 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => - TypeCoercion.findWiderTypeForTwo( - t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e0424b7478122..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +44,6 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -55,7 +53,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions, caseSensitive)) + Some(inferField(parser, configOptions)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -70,7 +68,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -100,15 +98,14 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField( - parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions, caseSensitive) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -125,7 +122,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions, caseSensitive), + inferField(parser, configOptions), nullable = true) } val fields: Array[StructField] = builder.result() @@ -140,7 +137,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) + elementType, inferField(parser, configOptions)) } ArrayType(elementType) @@ -246,14 +243,13 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode, - caseSensitive: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -263,7 +259,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) + case (ty1, ty2) => compatibleType(ty1, ty2) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -271,8 +267,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { - TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { + def compatibleType(t1: DataType, t2: DataType): DataType = { + TypeCoercion.findTightestCommonType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -307,8 +303,7 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType( - fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -331,17 +326,15 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType( - compatibleType(elementType1, elementType2, caseSensitive), - containsNull1 || containsNull2) + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2, caseSensitive) + compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2), caseSensitive) + compatibleType(t1, DecimalType.forType(t2)) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 34d23ee53220d..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) + var actual = compatibleType(t1, t2) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) + actual = compatibleType(t2, t1) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 6fb7d6c4f71be0007942f7d1fc3099f1bcf8c52b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 17 May 2018 00:40:39 +0800 Subject: [PATCH 500/593] [SPARK-24275][SQL] Revise doc comments in InputPartition ## What changes were proposed in this pull request? In #21145, DataReaderFactory is renamed to InputPartition. This PR is to revise wording in the comments to make it more clear. ## How was this patch tested? None Author: Gengliang Wang Closes #21326 from gengliangwang/revise_reader_comments. --- .../spark/sql/sources/v2/ReadSupport.java | 2 +- .../sql/sources/v2/ReadSupportWithSchema.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 2 +- .../sql/sources/v2/reader/DataSourceReader.java | 16 ++++++++-------- .../sql/sources/v2/reader/InputPartition.java | 17 +++++++++-------- .../sql/sources/v2/writer/DataSourceWriter.java | 6 +++--- .../sources/v2/writer/DataWriterFactory.java | 2 +- 7 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 0ea4dc6b5def3..b2526ded53d92 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 { /** * Creates a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param options the options for the returned data source reader, which is an immutable diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 3801402268af1..f31659904cc53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 { /** * Create a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param schema the full schema of this data source reader. Full schema usually maps to the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index cab56453816cc..83aeec0c47853 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 { * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index f898c296e4245..36a3e542b5a11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,7 +31,7 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link InputPartition}s that are returned by + * logic is delegated to {@link InputPartition}s, which are returned by * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: @@ -45,8 +45,8 @@ * only one of them would be respected, according to the priority list from high to low: * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * - * If an exception was throw when applying any of these query optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these query optimizations, the action will fail + * and no Spark job will be submitted. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark @@ -59,21 +59,21 @@ public interface DataSourceReader { * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for creating a data reader to - * output data for one RDD partition. That means the number of tasks returned here is same as - * the number of RDD partitions this scan outputs. + * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for + * creating a data reader to output data of one RDD partition. The number of input partitions + * returned here is the same as the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ List> planInputPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index c581e3b5d0047..3524481784fea 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -23,13 +23,14 @@ /** * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is - * responsible for creating the actual data reader. The relationship between - * {@link InputPartition} and {@link InputPartitionReader} + * responsible for creating the actual data reader of one RDD partition. + * The relationship between {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that input partitions will be serialized and sent to executors, then the partition reader - * will be created on executors and do the actual reading. So {@link InputPartition} must be - * serializable and {@link InputPartitionReader} doesn't need to be. + * Note that {@link InputPartition}s will be serialized and sent to executors, then + * {@link InputPartitionReader}s will be created on executors to do the actual reading. So + * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to + * be. */ @InterfaceStability.Evolving public interface InputPartition extends Serializable { @@ -41,10 +42,10 @@ public interface InputPartition extends Serializable { * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in - * the returned locations. By default this method returns empty string array, which means this - * task has no location preference. + * the returned locations. The default return value is empty string array, which means this + * input partition's reader has no location preference. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ default String[] preferredLocations() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0a0fd8db58035..0030a9f05dba7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -34,8 +34,8 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * - * If an exception was throw when applying any of these writing optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these writing optimizations, the action will fail + * and no Spark job will be submitted. * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the @@ -58,7 +58,7 @@ public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ DataWriterFactory createWriterFactory(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index c2c2ab73257e8..7527bcc0c4027 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,7 +35,7 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param partitionId A unique id of the RDD partition that the returned writer will process. From 8e60a16b73490007fe1c480d77cc09d760f0a02b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 16 May 2018 13:34:54 -0700 Subject: [PATCH 501/593] [SPARK-23601][BUILD][FOLLOW-UP] Keep md5 checksums for nexus artifacts. The repository.apache.org server still requires md5 checksums or it won't publish the staging repo. Author: Marcelo Vanzin Closes #21338 from vanzin/SPARK-23601. --- dev/create-release/release-build.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c00b00b845401..5faa3d3260a56 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -371,11 +371,18 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc and .sha1 - it really doesn't like anything else there + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 991726f31a8d182ed6d5b0e59185d97c0c5c532f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 May 2018 14:55:02 -0700 Subject: [PATCH 502/593] [SPARK-24158][SS] Enable no-data batches for streaming joins ## What changes were proposed in this pull request? This is a continuation of the larger task of enabling zero-data batches for more eager state cleanup. This PR enables it for stream-stream joins. ## How was this patch tested? - Updated join tests. Additionally, updated them to not use `CheckLastBatch` anywhere to set good precedence for future. Author: Tathagata Das Closes #21253 from tdas/SPARK-24158. --- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 14 +- .../spark/sql/streaming/StreamTest.scala | 15 +- .../sql/streaming/StreamingJoinSuite.scala | 217 +++++++++--------- 4 files changed, 130 insertions(+), 118 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..37a0b9d6c8728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -361,7 +361,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( - "Stream stream joins without equality predicate is not supported", plan = Some(plan)) + "Stream-stream join without equality predicate is not supported", plan = Some(plan)) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index fa7c8ee906ecd..afa664eb76525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val watermarkUsedForStateCleanup = + stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty + + // Latest watermark value is more than that used in this previous executed plan + val watermarkHasChanged = + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + + watermarkUsedForStateCleanup && watermarkHasChanged + } + protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) @@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec( // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. val cleanupIter = joinType match { - case Inner => - leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() case _ => throwBadJoinTypeException() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 9d139a927bea5..f348dac1319cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -199,15 +199,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName" - private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - - private def operatorName = "CheckNewAnswer" + override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}" } object CheckNewAnswer { @@ -218,6 +215,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) } + + def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows) } /** Stops the stream. It must currently be running. */ @@ -747,7 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } } - pos += 1 } try { @@ -761,8 +759,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { actns.foreach(executeAction) } + pos += 1 - case action: StreamAction => executeAction(action) + case action: StreamAction => + executeAction(action) + pos += 1 } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index da8f9608c1e9c..1f62357e6d09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1), CheckAnswer(), AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckLastBatch((10, 20, 30)), + CheckNewAnswer((10, 20, 30)), AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), StopStream, StartStream(), AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckLastBatch((1, 2, 3), (1, 2, 3)), + CheckNewAnswer((1, 2, 3), (1, 2, 3)), StopStream, StartStream(), AddData(input1, 100), AddData(input2, 100), - CheckLastBatch((100, 200, 300)) + CheckNewAnswer((100, 200, 300)) ) } @@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( AddData(input1, 1), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckNewAnswer((1, 10, 2, 3)), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), + CheckNewAnswer(), StopStream, StartStream(), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), + CheckNewAnswer((25, 30, 50, 75)), AddData(input1, 1), - CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark StopStream, StartStream(), AddData(input1, 5), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 5), - CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark ) } @@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 1, updated = 1), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckAnswer((1, 10, 2, 3)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + CheckNewAnswer((25, 30, 50, 75)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input2, 1), - CheckLastBatch(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 AddData(input1, 5), - CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + CheckNewAnswer(), // Same reason as above assertNumStateRows(total = 2, updated = 0) ) } @@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5)), CheckAnswer(), AddData(rightInput, (1, 11)), - CheckLastBatch((1, 5, 11)), + CheckNewAnswer((1, 5, 11)), AddData(rightInput, (1, 10)), - CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), - CheckLastBatch((1, 3, 10), (1, 3, 11)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer(), // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - - // Run another batch with event time = 25 to clear right state where rightTime <= 25 - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), // New data to right input should match with left side (1, 3) and (1, 5), as left state should // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and // state rows with rightTime <= 25 should be removed from state. // (1, 20) ==> filtered by event time watermark = 20 // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as state watermark = 25 + // as 21 < state watermark = 25 // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1), // New data to left input with leftTime <= 20 should be filtered due to event time watermark AddData(leftInput, (1, 20), (1, 21)), - CheckLastBatch((1, 21, 28)), - assertNumStateRows(total = 7, updated = 1) + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1) ) } @@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 20)), CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), - CheckLastBatch(), // matches with nothing on the left + CheckNewAnswer(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 5), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) // Should drop < 20 from left, i.e., none // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) // Should drop < 25 from the right, i.e., 14 and 15 - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat - CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // only 31 added // Advance the watermark AddData(rightInput, (1, 80)), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 1), - + CheckNewAnswer(), // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) // Should drop < 36 from left, i.e., 20, 31 (30 was not added) // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - AddData(rightInput, (1, 50)), - CheckLastBatch((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 50 added ) } @@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with input1.addData(1) q.awaitTermination(10000) } - assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } test("stream stream self join") { @@ -404,10 +402,11 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1, 5), AddData(input2, 1, 5, 10), AddData(input3, 5, 10), - CheckLastBatch((5, 10, 5, 15, 5, 25))) + CheckNewAnswer((5, 10, 5, 15, 5, 25))) } } + class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { import testImplicits._ @@ -465,13 +464,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -492,15 +491,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with value <= 7 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The right rows with rightValue <= 7 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -521,15 +520,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with value <= 4 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The left rows with leftValue <= 4 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -552,13 +551,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -568,14 +567,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -586,14 +585,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -627,21 +626,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), AddData(rightInput, (1, 10), (2, 5)), - CheckLastBatch((1, 1, 5, 10)), + CheckNewAnswer((1, 1, 5, 10)), AddData(rightInput, (1, 11)), - CheckLastBatch(), // no match as left time is too low + CheckNewAnswer(), // no match as left time is too low assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), - CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)), assertNumStateRows(total = 7, updated = 2), - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 1), - AddData(rightInput, (0, 30)), - CheckLastBatch(outerResult), - assertNumStateRows(total = 3, updated = 1) + AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls + CheckNewAnswer(outerResult), + assertNumStateRows(total = 2, updated = 1) ) } } @@ -665,36 +661,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), - CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added + + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state + AddData(rightInput, 20), - CheckLastBatch( - Row(20, 30, 40, 60)), + CheckNewAnswer(Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows - MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), - CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - MultiAddData(leftInput, 70)(rightInput, 71), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 6), // all inputs added since last check + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31 + CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)), + assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40,41 left in state + + MultiAddData(leftInput, 70)(rightInput, 71), // watermark = 60 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state + AddData(rightInput, 70), - CheckLastBatch((70, 80, 140, 210)), + CheckNewAnswer((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left - MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), - CheckLastBatch(), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91 + CheckNewAnswer(), + assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state + MultiAddData(leftInput, 1000)(rightInput, 1001), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added - AddData(rightInput, 1000), - CheckLastBatch( - Row(1000, 1010, 2000, 3000), + CheckNewAnswer( Row(101, 110, 202, null), Row(102, 110, 204, null), Row(103, 110, 206, null)), - assertNumStateRows(total = 3, updated = 1) + assertNumStateRows(total = 2, updated = 2) ) } } From bfd75cdfb22a8c2fb005da597621e1ccd3990e82 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Wed, 16 May 2018 17:54:06 -0700 Subject: [PATCH 503/593] [SPARK-22210][ML] Add seed for LDA variationalTopicInference ## What changes were proposed in this pull request? - Add seed parameter for variationalTopicInference - Add seed for calling variationalTopicInference in submitMiniBatch - Add var seed in LDAModel so that it can take the seed from LDA and use it for the function call of variationalTopicInference in logLikelihoodBound, topicDistributions, getTopicDistributionMethod, and topicDistribution. ## How was this patch tested? Check the test result in mllib.clustering.LDASuite to make sure the result is repeatable with the seed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21183 from ludatabricks/SPARK-22210. --- .../org/apache/spark/ml/clustering/LDA.scala | 6 ++- .../spark/mllib/clustering/LDAModel.scala | 34 ++++++++++++--- .../spark/mllib/clustering/LDAOptimizer.scala | 42 +++++++++++-------- .../apache/spark/ml/clustering/LDASuite.scala | 6 +++ 4 files changed, 64 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index afe599cd167cb..fed42c959b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -569,10 +569,14 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel, + private[clustering] val oldLocalModel_ : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + override private[clustering] def oldLocalModel: OldLocalLDAModel = { + oldLocalModel_.setSeed(getSeed) + } + @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b8a6e94248421..f915062d77389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** * Latent Dirichlet Allocation (LDA) model. @@ -194,6 +194,8 @@ class LocalLDAModel private[spark] ( override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { + private var seed: Long = Utils.random.nextLong() + @Since("1.3.0") override def k: Int = topics.numCols @@ -216,6 +218,21 @@ class LocalLDAModel private[spark] ( override protected def formatVersion = "1.0" + /** + * Random seed for cluster initialization. + */ + @Since("2.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. + */ + @Since("2.4.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, @@ -298,6 +315,7 @@ class LocalLDAModel private[spark] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + val gammaSeed = this.seed // Sum bound components for each document: // component for prob(tokens) + component for prob(document-topic distribution) @@ -306,7 +324,7 @@ class LocalLDAModel private[spark] ( val localElogbeta = ElogbetaBc.value var docBound = 0.0D val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id) val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) // E[log p(doc | theta, beta)] @@ -352,6 +370,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -362,7 +381,8 @@ class LocalLDAModel private[spark] ( expElogbetaBc.value, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed + id) (id, Vectors.dense(normalize(gamma, 1.0).toArray)) } } @@ -376,6 +396,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed (termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -386,7 +407,8 @@ class LocalLDAModel private[spark] ( expElogbeta, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } @@ -403,6 +425,7 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { + val gammaSeed = this.seed val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) @@ -412,7 +435,8 @@ class LocalLDAModel private[spark] ( expElogbeta, this.docConcentration.asBreeze, gammaShape, - this.k) + this.k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 693a2a31f026b..f8e5f3ed76457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val optimizeDocConcentration = this.optimizeDocConcentration + val seed = randomGenerator.nextLong() // If and only if optimizeDocConcentration is set true, // we calculate logphat in the same pass as other statistics. // No calculation of loghat happens otherwise. @@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { None } - val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - - val stat = BDM.zeros[Double](k, vocabSize) - val logphatPartOption = logphatPartOptionBase() - var nonEmptyDocCount: Long = 0L - nonEmptyDocs.foreach { case (_, termCounts: Vector) => - nonEmptyDocCount += 1 - val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids) + sstats - logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) - } - Iterator((stat, logphatPartOption, nonEmptyDocCount)) + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex { + (index, docs) => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L + nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 + val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index) + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) + } + Iterator((stat, logphatPartOption, nonEmptyDocCount)) } val elementWiseSum = ( @@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { } override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta) + .setSeed(randomGenerator.nextLong()) } } @@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer { expElogbeta: BDM[Double], alpha: breeze.linalg.Vector[Double], gammaShape: Double, - k: Int): (BDV[Double], BDM[Double], List[Int]) = { + k: Int, + seed: Long): (BDV[Double], BDM[Double], List[Int]) = { val (ids: List[Int], cts: Array[Double]) = termCounts match { case v: DenseVector => ((0 until v.size).toList, v.values) case v: SparseVector => (v.indices.toList, v.values) } // Initialize the variational distribution q(theta|gamma) for the mini-batch + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 8d728f063dd8c..4d848205034c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -253,6 +253,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.allParamSettings, checkModelData) + + // Make sure the result is deterministic after saving and loading the model + val model = lda.fit(dataset) + val model2 = testDefaultReadWrite(model) + assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6) + assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6) } test("read/write DistributedLDAModel") { From 9a641e7f721d01d283afb09dccefaf32972d3c04 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 17 May 2018 12:07:58 +0800 Subject: [PATCH 504/593] [SPARK-21945][YARN][PYTHON] Make --py-files work with PySpark shell in Yarn client mode ## What changes were proposed in this pull request? ### Problem When we run _PySpark shell with Yarn client mode_, specified `--py-files` are not recognised in _driver side_. Here are the steps I took to check: ```bash $ cat /home/spark/tmp.py def testtest(): return 1 ``` ```bash $ ./bin/pyspark --master yarn --deploy-mode client --py-files /home/spark/tmp.py ``` ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() # executor side [1] >>> test() # driver side Traceback (most recent call last): File "", line 1, in File "", line 2, in test ImportError: No module named tmp ``` ### How did it happen? Unlike Yarn cluster and client mode with Spark submit, when Yarn client mode with PySpark shell specifically, 1. It first runs Python shell via: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java#L158 as pointed out by tgravescs in the JIRA. 2. this triggers shell.py and submit another application to launch a py4j gateway: https://github.com/apache/spark/blob/209b9361ac8a4410ff797cff1115e1888e2f7e66/python/pyspark/java_gateway.py#L45-L60 3. it runs a Py4J gateway: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L425 4. it copies (or downloads) --py-files into local temp directory: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L365-L376 and then these files are set up to `spark.submit.pyFiles` 5. Py4J JVM is launched and then the Python paths are set via: https://github.com/apache/spark/blob/7013eea11cb32b1e0038dc751c485da5c94a484b/python/pyspark/context.py#L209-L216 However, these are not actually set because those files were copied into a tmp directory in 4. whereas this code path looks for `SparkFiles.getRootDirectory` where the files are stored only when `SparkContext.addFile()` is called. In other cluster mode, `spark.files` are set via: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L554-L555 and those files are explicitly added via: https://github.com/apache/spark/blob/ecb8b383af1cf1b67f3111c148229e00c9c17c40/core/src/main/scala/org/apache/spark/SparkContext.scala#L395 So we are fine in other modes. In case of Yarn client and cluster with _submit_, these are manually being handled. In particular https://github.com/apache/spark/pull/6360 added most of the logics. In this case, the Python path looks manually set via, for example, `deploy.PythonRunner`. We don't use `spark.files` here. ### How does the PR fix the problem? I tried to make an isolated approach as possible as I can: simply copy py file or zip files into `SparkFiles.getRootDirectory()` in driver side if not existing. Another possible way is to set `spark.files` but it does unnecessary stuff together and sounds a bit invasive. **Before** ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() [1] >>> test() Traceback (most recent call last): File "", line 1, in File "", line 2, in test ImportError: No module named tmp ``` **After** ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() [1] >>> test() 1 ``` ## How was this patch tested? I manually tested in standalone and yarn cluster with PySpark shell. .zip and .py files were also tested with the similar steps above. It's difficult to add a test. Author: hyukjinkwon Closes #21267 from HyukjinKwon/SPARK-21945. --- python/pyspark/context.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dbb463f6005a1..ede3b6af0a8cf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: - self._python_includes.append(filename) - sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + try: + filepath = os.path.join(SparkFiles.getRootDirectory(), filename) + if not os.path.exists(filepath): + # In case of YARN with shell mode, 'spark.submit.pyFiles' files are + # not added via SparkContext.addFile. Here we check if the file exists, + # try to copy and then add it to the path. See SPARK-21945. + shutil.copyfile(path, filepath) + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: + self._python_includes.append(filename) + sys.path.insert(1, filepath) + except Exception: + warnings.warn( + "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to " + "Python path:\n %s" % (path, "\n ".join(sys.path)), + RuntimeWarning) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) From 3e66350c2477a456560302b7738c9d122d5d9c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florent=20P=C3=A9pin?= Date: Thu, 17 May 2018 13:31:14 +0900 Subject: [PATCH 505/593] [SPARK-23925][SQL] Add array_repeat collection function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The PR adds a new collection function, array_repeat. As there already was a function repeat with the same signature, with the only difference being the expected return type (String instead of Array), the new function is called array_repeat to distinguish. The behaviour of the function is based on Presto's one. The function creates an array containing a given element repeated the requested number of times. ## How was this patch tested? New unit tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite Author: Florent Pépin Author: Florent Pépin Closes #21208 from pepinoflo/SPARK-23925. --- python/pyspark/sql/functions.py | 14 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 149 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 18 +++ .../org/apache/spark/sql/functions.scala | 20 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 76 +++++++++ 6 files changed, 278 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6866c1cf9f882..925ac34196f4c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2329,6 +2329,20 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_repeat(col, count): + """ + Collection function: creates an array containing a column repeated count times. + + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=[u'ab', u'ab', u'ab'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 087d000a9db70..9c370599bc0df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -427,6 +427,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 12b9ab2b272ab..2a4e42d4ba316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1468,3 +1468,152 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value + val count = rightGen.value + val et = dataType.elementType + + val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { + genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) + } else { + genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) + } + val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) + + ev.copy(code = + s""" + |boolean ${ev.isNull} = false; + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = + | ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin) + } + + private def nullElementsProtection( + ev: ExprCode, + rightIsNull: String, + coreLogic: String): String = { + if (nullable) { + s""" + |if ($rightIsNull) { + | ${ev.isNull} = true; + |} else { + | ${coreLogic} + |} + """.stripMargin + } else { + coreLogic + } + } + + private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { + val numElements = ctx.freshName("numElements") + val numElementsCode = + s""" + |int $numElements = 0; + |if ($count > 0) { + | $numElements = $count; + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (numElements, numElementsCode) + } + + private def genCodeForPrimitiveElement( + ctx: CodegenContext, + elementType: DataType, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val tempArrayDataName = ctx.freshName("tempArrayData") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val errorMessage = s" $prettyName failed." + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |if (!$leftIsNull) { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | } + |} else { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.setNullAt(k); + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForNonPrimitiveElement( + ctx: CodegenContext, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |if (!$leftIsNull) { + | for (int k = 0; k < $numElemName; k++) { + | $arrayName[k] = $element; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2851d071c7c6..57fc5f75dbca7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -468,4 +468,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asa3), null) checkEvaluation(Flatten(asa4), null) } + + test("ArrayRepeat") { + val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType)) + + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi")) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi")) + checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true)) + checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1)) + checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2)) + checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null)) + checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null)) + checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2))) + checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b71dfdad8aa9b..550571a61a036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3447,6 +3447,26 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(left: Column, right: Column): Column = withExpr { + ArrayRepeat(left.expr, right.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ecce06f4c0755..e26565cd153b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -843,6 +843,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array_repeat function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val strDF = Seq( + ("hi", 2), + (null, 2) + ).toDF("a", "b") + + val strDFTwiceResult = Seq( + Row(Seq("hi", "hi")), + Row(Seq(null, null)) + ) + + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + + val intDF = { + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType))) + val data = Seq( + Row(3, 2), + Row(null, 2) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val intDFTwiceResult = Seq( + Row(Seq(3, 3)), + Row(Seq(null, null)) + ) + + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + + val nullCountDF = { + val schema = StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))) + val data = Seq( + Row("hi", null), + Row(null, null) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq( + Row(null), + Row(null) + ) + ) + + // Error test cases + val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") + + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", $"b")) + } + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", lit("1"))) + } + intercept[AnalysisException] { + invalidTypeDF.selectExpr("array_repeat(a, 1.0)") + } + + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 6c35865d949a8b46f654cd53c7e5f3288def18d0 Mon Sep 17 00:00:00 2001 From: Artem Rudoy Date: Thu, 17 May 2018 18:49:46 +0800 Subject: [PATCH 506/593] [SPARK-22371][CORE] Return None instead of throwing an exception when an accumulator is garbage collected. ## What changes were proposed in this pull request? There's a period of time when an accumulator has been garbage collected, but hasn't been removed from AccumulatorContext.originals by ContextCleaner. When an update is received for such accumulator it will throw an exception and kill the whole job. This can happen when a stage completes, but there're still running tasks from other attempts, speculation etc. Since AccumulatorContext.get() returns an option we can just return None in such case. ## How was this patch tested? Unit test. Author: Artem Rudoy Closes #21114 from artemrd/SPARK-22371. --- .../org/apache/spark/util/AccumulatorV2.scala | 14 +++++++++----- .../scala/org/apache/spark/AccumulatorSuite.scala | 6 ++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 2bc84953a56eb..3b469a69437b9 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo private[spark] case class AccumulatorMetadata( @@ -211,7 +212,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { /** * An internal class used to track accumulators by Spark itself. */ -private[spark] object AccumulatorContext { +private[spark] object AccumulatorContext extends Logging { /** * This global map holds the original accumulator objects that are created on the driver. @@ -258,13 +259,16 @@ private[spark] object AccumulatorContext { * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ def get(id: Long): Option[AccumulatorV2[_, _]] = { - Option(originals.get(id)).map { ref => - // Since we are storing weak references, we must check whether the underlying data is valid. + val ref = originals.get(id) + if (ref eq null) { + None + } else { + // Since we are storing weak references, warn when the underlying data is not valid. val acc = ref.get if (acc eq null) { - throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") + logWarning(s"Attempted to access garbage collected accumulator $id") } - acc + Option(acc) } } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 3990ee1ec326d..5d0ffd92647bc 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex System.gc() assert(ref.get.isEmpty) - // Getting a garbage collected accum should throw error - intercept[IllegalStateException] { - AccumulatorContext.get(accId) - } + // Getting a garbage collected accum should return None. + assert(AccumulatorContext.get(accId).isEmpty) // Getting a normal accumulator. Note: this has to be separate because referencing an // accumulator above in an `assert` would keep it from being garbage collected. From 6ec05826d7b0a512847e2522564e01256c8d192d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 May 2018 20:42:40 +0800 Subject: [PATCH 507/593] [SPARK-24107][CORE][FOLLOWUP] ChunkedByteBuffer.writeFully method has not reset the limit value ## What changes were proposed in this pull request? According to the discussion in https://github.com/apache/spark/pull/21175 , this PR proposes 2 improvements: 1. add comments to explain why we call `limit` to write out `ByteBuffer` with slices. 2. remove the `try ... finally` ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21327 from cloud-fan/minor. --- .../spark/util/io/ChunkedByteBuffer.scala | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 3ae8dfcc1cb66..700ce56466c35 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,15 +63,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - val curChunkLimit = bytes.limit() + val originalLimit = bytes.limit() while (bytes.hasRemaining) { - try { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) - } finally { - bytes.limit(curChunkLimit) - } + // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct + // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread. + // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may + // cause significant native memory leak, if a large direct ByteBuffer is allocated and + // cached, as it's never released until thread exits. Here we write the `bytes` with + // fixed-size slices to limit the size of the cached direct ByteBuffer. + // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details. + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + bytes.limit(originalLimit) } } } From 69350aa2f0a7aee4dcb1067f073b61a0b9f9cb51 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 17 May 2018 20:45:32 +0800 Subject: [PATCH 508/593] [SPARK-23922][SQL] Add arrays_overlap function ## What changes were proposed in this pull request? The PR adds the function `arrays_overlap`. This function returns `true` if the input arrays contain a non-null common element; if not, it returns `null` if any of the arrays contains a `null` element, `false` otherwise. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21028 from mgaido91/SPARK-23922. --- python/pyspark/sql/functions.py | 15 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 267 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 66 +++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 29 ++ 6 files changed, 388 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 925ac34196f4c..8490081facc5a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1855,6 +1855,21 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + @since(2.4) def slice(x, start, length): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9c370599bc0df..867c2d5eab53d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2a4e42d4ba316..c82db839438ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,15 +18,51 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. + */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + } + } +} + + /** * Given an array or map, returns its size. Returns -1 if null. */ @@ -529,6 +565,235 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Checks if the two arrays contain at least one common element. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +// scalastyle:off line.size.limit +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryArrayExpressionWithImplicitCast { + + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (RowOrdering.isOrderable(elementType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") + } + case failure => failure + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + fastEval _ + } else { + bruteForceEval _ + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) + } + + /** + * A fast implementation which puts all the elements from the smaller array in a set + * and then performs a lookup on it for each element of the bigger one. + * This eval mode works only for data types which implements properly the equals method. + */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smaller.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smaller.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + smallestSet += v + }) + bigger.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else if (smallestSet.contains(v1)) { + return true + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + /** + * A slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0 && arr2.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val comparisonCode = if (elementTypeSupportEquals) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } + s""" + |ArrayData $smaller; + |ArrayData $bigger; + |if ($a1.numElements() > $a2.numElements()) { + | $bigger = $a1; + | $smaller = $a2; + |} else { + | $smaller = $a1; + | $bigger = $a2; + |} + |if ($smaller.numElements() > 0) { + | $comparisonCode + |} + """.stripMargin + }) + } + + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. + */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { + | $compareValues + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { + | $isInSmaller + |} + """.stripMargin + } + + def nullSafeElementCodegen( + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } + + override def prettyName: String = "arrays_overlap" +} + /** * Slices an array according to the requested start index and length */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 57fc5f75dbca7..6ae1ac18c4dc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) + + // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) + } + test("Slice") { val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 550571a61a036..2a8fe583b83bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3085,6 +3085,17 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + /** * Returns an array containing all the elements in `x` from index `start` (or starting from the * end if `start` is negative) with the specified `length`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e26565cd153b4..d08982a138bc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) + + intercept[AnalysisException] { + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } + } + test("slice function") { val df = Seq( Seq(1, 2, 3), From 8a837bf4f3f2758f7825d2362cf9de209026651a Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 17 May 2018 22:29:18 +0800 Subject: [PATCH 509/593] [SPARK-24193] create TakeOrderedAndProjectExec only when the limit number is below spark.sql.execution.topKSortFallbackThreshold. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Physical plan of `select colA from t order by colB limit M` is `TakeOrderedAndProject`; Currently `TakeOrderedAndProject` sorts data in memory, see https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala#L158 We can add a config – if the number of limit (M) is too big, we can sort by disk. Thus memory issue can be resolved. ## How was this patch tested? Test added Author: jinxing Closes #21252 from jinxing64/SPARK-24193. --- .../org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 12 ++++++++---- .../apache/spark/sql/execution/PlannerSuite.scala | 12 ++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..2a673c6ce8f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1253,6 +1253,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .internal() + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .intConf + .createWithDefault(Int.MaxValue) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1424,6 +1433,8 @@ class SQLConf extends Serializable with Logging { def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 37a0b9d6c8728..b97a87a122406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,9 +66,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => // With whole stage codegen, Spark releases resources only when all the output data of the @@ -79,9 +81,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f0dfe6b76f7ae..a375f881c7d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -197,6 +197,18 @@ class PlannerSuite extends SharedSQLContext { assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } + test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { + val query0 = testData.select('value).orderBy('key).limit(100) + val planned0 = query0.queryExecution.executedPlan + assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + + val query1 = testData.select('value).orderBy('key).limit(2000) + val planned1 = query1.queryExecution.executedPlan + assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + } + } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { val query = testData.select('key, 'value).sort('key.desc).cache() assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) From a7a9b1837808b281f47643490abcf054f6de7b50 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 17 May 2018 11:13:16 -0700 Subject: [PATCH 510/593] [SPARK-24115] Have logging pass through instrumentation class. ## What changes were proposed in this pull request? Fixes to tuning instrumentation. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21340 from MrBago/tunning-instrumentation. --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++----- .../apache/spark/ml/tuning/TrainValidationSplit.scala | 10 +++++----- .../org/apache/spark/ml/util/Instrumentation.scala | 7 +++++++ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5e916cc4a9fdd..f327f37bad204 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() - logDebug(s"Train split $splitIndex with multiple sets of parameters.") + instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => @@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits - logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best cross-validation metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 13369c4df7180..14d6a69c36747 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } else None // Fit models in a Future for training in parallel - logDebug(s"Train split with multiple sets of parameters.") + instr.logDebug(s"Train split with multiple sets of parameters.") val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] @@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.unpersist() validationDataset.unpersist() - logInfo(s"Train validation split metrics: ${metrics.toSeq}") + instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best train validation split metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3247c394dfa64..467130b37c16e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -58,6 +58,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( s" storageLevel=${dataset.getStorageLevel}") } + /** + * Logs a debug message with a prefix that uniquely identifies the training session. + */ + override def logDebug(msg: => String): Unit = { + super.logDebug(prefix + msg) + } + /** * Logs a warning message with a prefix that uniquely identifies the training session. */ From 439c69511812776cb4b82956547ce958d0669c52 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 17 May 2018 13:42:10 -0700 Subject: [PATCH 511/593] [SPARK-24114] Add instrumentation to FPGrowth. ## What changes were proposed in this pull request? Have FPGrowth keep track of model training using the Instrumentation class. ## How was this patch tested? manually Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21344 from MrBago/fpgrowth-instr. --- mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 0bf405d9abf9d..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -161,6 +161,8 @@ class FPGrowth @Since("2.2.0") ( private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val instr = Instrumentation.create(this, dataset) + instr.logParams(params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -183,7 +185,9 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + instr.logSuccess(model) + model } @Since("2.2.0") From d4a0895c628ca854895c3c35c46ed990af36ec61 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Thu, 17 May 2018 16:33:06 -0700 Subject: [PATCH 512/593] [SPARK-22884][ML] ML tests for StructuredStreaming: spark.ml.clustering ## What changes were proposed in this pull request? Converting clustering tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. This PR is a new version of https://github.com/apache/spark/pull/20319 Author: Sandor Murakozi Author: Joseph K. Bradley Closes #21358 from jkbradley/smurakozi-SPARK-22884. --- .../ml/clustering/BisectingKMeansSuite.scala | 41 ++++++++++--------- .../ml/clustering/GaussianMixtureSuite.scala | 22 ++++------ .../spark/ml/clustering/KMeansSuite.scala | 31 +++++++------- .../apache/spark/ml/clustering/LDASuite.scala | 21 ++++------ 4 files changed, 50 insertions(+), 65 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f3ff2afcad2cd..81842afbddbbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -19,17 +19,18 @@ package org.apache.spark.ml.clustering import scala.language.existentials -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset -class BisectingKMeansSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -68,10 +69,13 @@ class BisectingKMeansSuite // Verify fit does not fail on very sparse data val model = bkm.fit(sparseDataset) - val result = model.transform(sparseDataset) - val numClusters = result.select("prediction").distinct().collect().length - // Verify we hit the edge case - assert(numClusters < k && numClusters > 1) + + testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") { + rows => + val numClusters = rows.distinct.length + // Verify we hit the edge case + assert(numClusters < k && numClusters > 1) + } } test("setter/getter") { @@ -104,19 +108,16 @@ class BisectingKMeansSuite val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } + // Check validity of model summary val numRows = dataset.count() assert(model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index d0d461a42711a..0b91f502f615b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -23,16 +23,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.stat.distribution.MultivariateGaussian -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} -class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { - import testImplicits._ +class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { + import GaussianMixtureSuite._ + import testImplicits._ final val k = 5 private val seed = 538009335 @@ -119,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.weights.length === k) assert(model.gaussians.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName, probabilityColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - // Check prediction matches the highest probability, and probabilities sum to one. - transformed.select(predictionColName, probabilityColName).collect().foreach { - case Row(pred: Int, prob: Vector) => + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName, probabilityColName) { + case Row(_, pred: Int, prob: Vector) => val probArray = prob.toArray val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2 assert(pred === predFromProb) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 680a7c2034083..2569e7a432ca4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,20 +22,21 @@ import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest - with PMMLReadWriteTest { +class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -109,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val model = kmeans.fit(dataset) assert(model.clusterCenters.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) @@ -149,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) - Seq(featuresColName, predictionColName).foreach { column => - assert(transformed.columns.contains(column)) - } + assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName)) assert(model.getFeaturesCol == featuresColName) assert(model.getPredictionCol == predictionColName) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 4d848205034c0..096b5416899e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -21,11 +21,9 @@ import scala.language.existentials import org.apache.hadoop.fs.Path -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ object LDASuite { @@ -61,7 +59,7 @@ object LDASuite { } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LDASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -186,16 +184,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.topicsMatrix.numCols === k) assert(!model.isDistributed) - // transform() - val transformed = model.transform(dataset) - val expectedColumns = Array("features", lda.getTopicDistributionCol) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - transformed.select(lda.getTopicDistributionCol).collect().foreach { r => - val topicDistribution = r.getAs[Vector](0) - assert(topicDistribution.size === k) - assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", lda.getTopicDistributionCol) { + case Row(_, topicDistribution: Vector) => + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) } // logLikelihood, logPerplexity From 7b2dca5b12164b787ec4e8e7e9f92c60a3f9563e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 18 May 2018 15:32:29 +0800 Subject: [PATCH 513/593] [SPARK-24277][SQL] Code clean up in SQL module: HadoopMapReduceCommitProtocol ## What changes were proposed in this pull request? In HadoopMapReduceCommitProtocol and FileFormatWriter, there are unnecessary settings in hadoop configuration. Also clean up some code in SQL module. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21329 from gengliangwang/codeCleanWrite. --- .../io/HadoopMapReduceCommitProtocol.scala | 15 +++------------ .../datasources/orc/OrcColumnVector.java | 6 +----- .../parquet/VectorizedRleValuesReader.java | 4 ++-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/views.scala | 10 ++++------ .../execution/datasources/FileFormatWriter.scala | 11 +++++------ .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- 7 files changed, 17 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 3e60c50ada59b..163511b7ffa3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -145,18 +145,9 @@ class HadoopMapReduceCommitProtocol( } override def setupJob(jobContext: JobContext): Unit = { - // Setup IDs - val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) - val taskId = new TaskID(jobId, TaskType.MAP, 0) - val taskAttemptId = new TaskAttemptID(taskId, 0) - - // Set up the configuration object - jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) - jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) - jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) - jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) - jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) - + // Create a dummy [[TaskAttemptContextImpl]] with configuration to get [[OutputCommitter]] + // instance on Spark driver. Note that the job/task/attampt id doesn't matter here. + val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) committer.setupJob(jobContext) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..fcf73e8d7ae6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -47,11 +47,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto OrcColumnVector(DataType type, ColumnVector vector) { super(type); - if (type instanceof TimestampType) { - isTimestamp = true; - } else { - isTimestamp = false; - } + isTimestamp = type instanceof TimestampType; baseData = vector; if (vector instanceof LongColumnVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fe3d31ae8e746..de0d65a1e0906 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -571,7 +571,7 @@ private int readIntLittleEndian() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); } /** @@ -592,7 +592,7 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); + return (ch1 << 16) + (ch2 << 8) + (ch3); } case 4: { return readIntLittleEndian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index af20764f9a968..265a84b39a425 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -188,7 +188,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] + val fields = SerDe.readList(dis, jvmObjectTracker = null) Row.fromSeq(fields) case _ => null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 5172f32ec7b9c..6373584b10e35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -410,12 +410,10 @@ object ViewHelper { } // Detect cyclic references from subqueries. - plan.expressions.foreach { expr => - expr match { - case s: SubqueryExpression => - checkCyclicViewReference(s.plan, path, viewIdent) - case _ => // Do nothing. - } + plan.expressions.foreach { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..681bb1df6bbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -244,18 +244,17 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) - val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) - val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the attempt context required to use in the output committer. val taskAttemptContext: TaskAttemptContext = { + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) hadoopConf.setBoolean("mapreduce.task.ismap", true) - hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -378,7 +377,7 @@ object FileFormatWriter extends Logging { dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.map(_.newFile(currentPath)) + statsTrackers.foreach(_.newFile(currentPath)) } override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { @@ -429,10 +428,10 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty + private val isPartitioned = desc.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined + private val isBucketed = desc.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d254af400a7cf..2c4d0bcf103ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -170,7 +170,7 @@ class SQLAppStatusListener( .filter { case (id, _) => metricIds.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) } // Check the execution again for whether the aggregated metrics data has been calculated. From 0cf59fcbe3799dd3c4469cbf8cd842d668a76f34 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 18 May 2018 09:53:24 -0700 Subject: [PATCH 514/593] [SPARK-24303][PYTHON] Update cloudpickle to v0.4.4 ## What changes were proposed in this pull request? cloudpickle 0.4.4 is released - https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.4 There's no invasive change - the main difference is that we are now able to pickle the root logger, which fix is pretty isolated. ## How was this patch tested? Jenkins tests. Author: hyukjinkwon Closes #21350 from HyukjinKwon/SPARK-24303. --- python/pyspark/cloudpickle.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index ea845b98b3db2..88519d7311fcc 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -272,7 +272,7 @@ def save_memoryview(self, obj): if not PY3: def save_buffer(self, obj): self.save(str(obj)) - dispatch[buffer] = save_buffer + dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3 def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) @@ -801,10 +801,10 @@ def save_ellipsis(self, obj): def save_not_implemented(self, obj): self.save_reduce(_gen_not_implemented, ()) - if PY3: - dispatch[io.TextIOWrapper] = save_file - else: + try: # Python 2 dispatch[file] = save_file + except NameError: # Python 3 + dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented @@ -819,6 +819,11 @@ def save_logger(self, obj): dispatch[logging.Logger] = save_logger + def save_root_logger(self, obj): + self.save_reduce(logging.getLogger, (), obj=obj) + + dispatch[logging.RootLogger] = save_root_logger + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" From 7696b9de0df6e9eb85a74bdb404409da693cf65e Mon Sep 17 00:00:00 2001 From: Soham Aurangabadkar Date: Fri, 18 May 2018 10:29:34 -0700 Subject: [PATCH 515/593] [SPARK-20538][SQL] Wrap Dataset.reduce with withNewRddExecutionId. ## What changes were proposed in this pull request? Wrap Dataset.reduce with `withNewExecutionId`. Author: Soham Aurangabadkar Closes #21316 from sohama4/dataset_reduce_withexecutionid. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f001f16e1d5ee..32267eb0300f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1617,7 +1617,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: From 807ba44cb742c5f7c22bdf6bfe2cf814be85398e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 May 2018 10:35:43 -0700 Subject: [PATCH 516/593] [SPARK-24159][SS] Enable no-data micro batches for streaming mapGroupswithState ## What changes were proposed in this pull request? Enabled no-data batches in flatMapGroupsWithState in following two cases. - When ProcessingTime timeout is used, then we always run a batch every trigger interval. - When event-time watermark is defined, then the user may be doing arbitrary logic against the watermark value even if timeouts are not set. In such cases, it's best to run batches whenever the watermark has changed, irrespective of whether timeouts (i.e. event-time timeout) have been explicitly enabled. ## How was this patch tested? updated tests Author: Tathagata Das Closes #21345 from tdas/SPARK-24159. --- .../FlatMapGroupsWithStateExec.scala | 17 ++- .../FlatMapGroupsWithStateSuite.scala | 120 ++++++++++-------- 2 files changed, 80 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 80769d728b8f1..8e82cccbc8fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec( case _ => iter } - // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all @@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timingOutPairs = store.getRange(None, None).filter { rowPair => val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => + timingOutPairs.flatMap { rowPair => callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b1416bff87ee7..988c8e6753e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -615,20 +615,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -657,15 +657,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) + CheckNewAnswer(("a", "1"), ("c", "1")) ) } @@ -694,22 +694,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Complete)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckNewAnswer(("a", 1)), AddData(inputData, "a", "b"), // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckNewAnswer(("a", 2), ("b", 1)), StopStream, StartStream(), AddData(inputData, "a", "b"), // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + CheckNewAnswer(("a", 3), ("b", 2)), StopStream, StartStream(), AddData(inputData, "a", "c"), // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) ) } @@ -729,8 +729,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } @@ -757,17 +757,17 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "b"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("b", "1")), + CheckNewAnswer(("b", "1")), assertNumStateRows(total = 2, updated = 1), AddData(inputData, "b"), AdvanceManualClock(10 * 1000), - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, @@ -775,38 +775,42 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "c"), AdvanceManualClock(11 * 1000), - CheckLastBatch(("b", "-1"), ("c", "1")), + CheckNewAnswer(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), - AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), - CheckLastBatch(("c", "2")), - assertNumStateRows(total = 1, updated = 1) + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows(total = 0, updated = 0) ) } test("flatMapGroupsWithState - streaming with event time timeout + watermark") { - // Function to maintain the max event time - // Returns the max event time in the state, or -1 if the state was removed by timeout + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - val timeoutDelay = 5 - if (key != "a") { - Iterator.empty + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) } else { - if (state.hasTimedOut) { - state.remove() - Iterator((key, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) - val timeoutTimestampMs = maxEventTime + timeoutDelay - state.update(maxEventTime) - state.setTimeoutTimestamp(timeoutTimestampMs * 1000) - Iterator((key, maxEventTime.toInt)) - } + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) } } val inputData = MemoryStream[(String, Int)] @@ -819,15 +823,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second")), - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... - CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckLastBatch(), // No output as data should get filtered by watermark - AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s - CheckLastBatch(), // No output as no data for "a" - AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored - CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 ) } @@ -856,20 +868,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -920,15 +932,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), + CheckNewAnswer(("a", 1L)), AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), + CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), ExpectFailure[SparkException](), // task should fail but should not increment count setFailInTask(false), StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count ) } @@ -938,7 +950,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch("a"), + CheckNewAnswer("a"), AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) ) } @@ -1000,7 +1012,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")) + CheckNewAnswer(("a", "1")) ) } } From ed7ba7db8fa344ff182b72d23ae458e711f63432 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 18 May 2018 11:14:22 -0700 Subject: [PATCH 517/593] [SPARK-23850][SQL] Add separate config for SQL options redaction. The old code was relying on a core configuration and extended its default value to include things that redact desired things in the app's environment. Instead, add a SQL-specific option for which options to redact, and apply both the core and SQL-specific rules when redacting the options in the save command. This is a little sub-optimal since it adds another config, but it retains the current default behavior. While there I also fixed a typo and a couple of minor config API usage issues in the related redaction option that SQL already had. Tested with existing unit tests, plus checking the env page on a shell UI. Author: Marcelo Vanzin Closes #21158 from vanzin/SPARK-23850. --- .../spark/internal/config/package.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 24 +++++++++++++++++-- .../sql/execution/DataSourceScanExec.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../SaveIntoDataSourceCommand.scala | 5 ++-- .../SaveIntoDataSourceCommandSuite.scala | 3 --- 6 files changed, 27 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 82f0a04e94b1c..a54b091a64d50 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -342,7 +342,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password|url|user|username".r) + .createWithDefault("(?i)secret|password".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2a673c6ce8f4a..53a50305348fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1155,8 +1155,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = + buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .regexConf + .createWithDefault("(?i)url".r) + val SQL_STRING_REDACTION_PATTERN = - ConfigBuilder("spark.sql.redaction.string.regex") + buildConf("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + "information. When this regex matches a string part, that string part is replaced by a " + "dummy value. This is currently used to redact the output of SQL explain commands. " + @@ -1429,7 +1438,7 @@ class SQLConf extends Serializable with Logging { def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) - def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) @@ -1738,6 +1747,17 @@ class SQLConf extends Serializable with Logging { }.toSeq } + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions(options: Map[String, String]): Map[String, String] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap + } + /** * Return whether a given key is set in this [[SQLConf]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..61c14fee09337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 15379a0663f7d..3112b306c365e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -225,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * Redact the sensitive information in the given string. */ private def withRedaction(message: String): String = { - Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message) } /** A special namespace for commands that can be used to debug query execution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 568e953a5db66..00b1b5dedb593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.CreatableRelationProvider -import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand( } override def simpleString: String = { - val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 4b3ca8e60cab6..a1da3ec43eae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class SaveIntoDataSourceCommandSuite extends SharedSQLContext { - override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.redaction.regex", "(?i)password|url") - test("simpleString is redacted") { val URL = "connection.url" val PASS = "123" From 1c4553d67de8089e8aa84bc736faa11f21615a6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 18 May 2018 12:51:09 -0700 Subject: [PATCH 518/593] Revert "[SPARK-24277][SQL] Code clean up in SQL module: HadoopMapReduceCommitProtocol" This reverts commit 7b2dca5b12164b787ec4e8e7e9f92c60a3f9563e. --- .../io/HadoopMapReduceCommitProtocol.scala | 15 ++++++++++++--- .../datasources/orc/OrcColumnVector.java | 6 +++++- .../parquet/VectorizedRleValuesReader.java | 4 ++-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/views.scala | 10 ++++++---- .../execution/datasources/FileFormatWriter.scala | 11 ++++++----- .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- 7 files changed, 33 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 163511b7ffa3a..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -145,9 +145,18 @@ class HadoopMapReduceCommitProtocol( } override def setupJob(jobContext: JobContext): Unit = { - // Create a dummy [[TaskAttemptContextImpl]] with configuration to get [[OutputCommitter]] - // instance on Spark driver. Note that the job/task/attampt id doesn't matter here. - val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + // Setup IDs + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) + jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) + jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) committer.setupJob(jobContext) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index fcf73e8d7ae6c..12f4d658b1868 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -47,7 +47,11 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto OrcColumnVector(DataType type, ColumnVector vector) { super(type); - isTimestamp = type instanceof TimestampType; + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } baseData = vector; if (vector instanceof LongColumnVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index de0d65a1e0906..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -571,7 +571,7 @@ private int readIntLittleEndian() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** @@ -592,7 +592,7 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return (ch1 << 16) + (ch2 << 8) + (ch3); + return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { return readIntLittleEndian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 265a84b39a425..af20764f9a968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -188,7 +188,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis, jvmObjectTracker = null) + val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] Row.fromSeq(fields) case _ => null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 6373584b10e35..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -410,10 +410,12 @@ object ViewHelper { } // Detect cyclic references from subqueries. - plan.expressions.foreach { - case s: SubqueryExpression => - checkCyclicViewReference(s.plan, path, viewIdent) - case _ => // Do nothing. + plan.expressions.foreach { expr => + expr match { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 681bb1df6bbae..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -244,17 +244,18 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the attempt context required to use in the output committer. val taskAttemptContext: TaskAttemptContext = { - val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) - val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -377,7 +378,7 @@ object FileFormatWriter extends Logging { dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.foreach(_.newFile(currentPath)) + statsTrackers.map(_.newFile(currentPath)) } override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { @@ -428,10 +429,10 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { /** Flag saying whether or not the data to be written out is partitioned. */ - private val isPartitioned = desc.partitionColumns.nonEmpty + val isPartitioned = desc.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - private val isBucketed = desc.bucketIdExpression.isDefined + val isBucketed = desc.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2c4d0bcf103ff..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -170,7 +170,7 @@ class SQLAppStatusListener( .filter { case (id, _) => metricIds.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) } // Check the execution again for whether the aggregated metrics data has been calculated. From 7f82c4a47e94ee4f544dee8bb71b99534e919769 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 May 2018 12:54:19 -0700 Subject: [PATCH 519/593] [SPARK-24312][SQL] Upgrade to 2.3.3 for Hive Metastore Client 2.3 ## What changes were proposed in this pull request? Hive 2.3.3 was [released on April 3rd](https://issues.apache.org/jira/secure/ReleaseNote.jspa?version=12342162&styleName=Text&projectId=12310843). This PR aims to upgrade Hive Metastore Client 2.3 from 2.3.2 to 2.3.3. ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #21359 from dongjoon-hyun/SPARK-24312. --- docs/sql-programming-guide.md | 4 ++-- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/client/package.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f79ed6422205..b93d8531d9efe 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used
    @@ -2237,7 +2237,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index bb134bbe68bd9..cd321d41f43e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.2.") + s"0.12.0 through 2.3.3.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c2690ec32b9e7..2f34f69b5cf48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -98,7 +98,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 681ee9200f02b..25e9886fa6576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.2", + case object v2_3 extends HiveVersion("2.3.3", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) From 3159ee085b23e2e9f1657d80b7ae3efe82b5edb9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 13:04:00 -0700 Subject: [PATCH 520/593] [SPARK-24149][YARN] Retrieve all federated namespaces tokens ## What changes were proposed in this pull request? Hadoop 3 introduces HDFS federation. This means that multiple namespaces are allowed on the same HDFS cluster. In Spark, we need to ask the delegation token for all the namenodes (for each namespace), otherwise accessing any other namespace different from the default one (for which we already fetch the delegation token) fails. The PR adds the automatic discovery of all the namenodes related to all the namespaces available according to the configs in hdfs-site.xml. ## How was this patch tested? manual tests in dockerized env Author: Marco Gaido Closes #21216 from mgaido91/SPARK-24149. --- docs/running-on-yarn.md | 9 ++- .../deploy/yarn/YarnSparkHadoopUtil.scala | 24 ++++++- .../yarn/YarnSparkHadoopUtilSuite.scala | 65 ++++++++++++++++++- 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c9e68c3bfd056..4dbcbeafbbd9d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -424,9 +424,12 @@ To use a custom metrics.properties for the application master and executors, upd Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. -In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home -directory, Spark will also automatically obtain delegation tokens for the service hosting the -staging directory of the Spark application. +In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens +for: + +- the filesystem hosting the staging directory of the Spark application (which is the default + filesystem if `spark.yarn.stagingDir` is not set); +- if Hadoop federation is enabled, all the federated filesystems in the configuration. If an application needs to interact with other secure Hadoop filesystems, their URIs need to be explicitly provided to Spark at launch time. This is done by listing them in the diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 8eda6cb1277c5..7250e58b6c49a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -200,7 +200,29 @@ object YarnSparkHadoopUtil { .map(new Path(_).getFileSystem(hadoopConf)) .getOrElse(FileSystem.get(hadoopConf)) - filesystemsToAccess + stagingFS + // Add the list of available namenodes for all namespaces in HDFS federation. + // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its + // namespaces. + val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { + Set.empty + } else { + val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") + // Retrieving the filesystem for the nameservices where HA is not enabled + val filesystemsWithoutHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode => + new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf) + } + } + // Retrieving the filesystem for the nameservices where HA is enabled + val filesystemsWithHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ => + new Path(s"hdfs://$ns").getFileSystem(hadoopConf) + } + } + (filesystemsWithoutHA ++ filesystemsWithHA).toSet + } + + filesystemsToAccess ++ hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index f21353aa007c8..61c0c43f7c04f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -21,7 +21,8 @@ import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.io.Text +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } + test("SPARK-24149: retrieve all namenodes from HDFS") { + val sparkConf = new SparkConf() + val basicFederationConf = new Configuration() + basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + basicFederationConf.set("dfs.nameservices", "ns1,ns2") + basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020") + basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021") + val basicFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf), + new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf)) + val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, basicFederationConf) + basicFederationResult should be (basicFederationExpected) + + // when viewfs is enabled, namespaces are handled by it, so we don't need to take care of them + val viewFsConf = new Configuration() + viewFsConf.addResource(basicFederationConf) + viewFsConf.set("fs.defaultFS", "viewfs://clusterX/") + viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/") + val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf)) + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected) + + // invalid config should not throw NullPointerException + val invalidFederationConf = new Configuration() + invalidFederationConf.addResource(basicFederationConf) + invalidFederationConf.unset("dfs.namenode.rpc-address.ns2") + val invalidFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf)) + val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, invalidFederationConf) + invalidFederationResult should be (invalidFederationExpected) + + // no namespaces defined, ie. old case + val noFederationConf = new Configuration() + noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + val noFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(noFederationConf)) + val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf) + noFederationResult should be (noFederationExpected) + + // federation and HA enabled + val federationAndHAConf = new Configuration() + federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA") + federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA") + federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2") + federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + + val federationAndHAExpected = Set( + new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf), + new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf)) + val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, federationAndHAConf) + federationAndHAResult should be (federationAndHAExpected) + } } From a53ea70c1d8903cdff051edf667b0127c8131a09 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 18 May 2018 13:38:36 -0700 Subject: [PATCH 521/593] [SPARK-23856][SQL] Add an option `queryTimeout` in JDBCOptions ## What changes were proposed in this pull request? This pr added an option `queryTimeout` for the number of seconds the the driver will wait for a Statement object to execute. ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro Closes #21173 from maropu/SPARK-23856. --- docs/sql-programming-guide.md | 11 +++++++++++ .../org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../datasources/jdbc/JDBCOptions.scala | 5 +++++ .../execution/datasources/jdbc/JDBCRDD.scala | 3 +++ .../jdbc/JdbcRelationProvider.scala | 2 +- .../execution/datasources/jdbc/JdbcUtils.scala | 16 +++++++++++++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 16 ++++++++++++++++ .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 18 ++++++++++++++++++ 8 files changed, 69 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b93d8531d9efe..f1ed316341b95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1338,6 +1338,17 @@ the following case-insensitive options: + + + + +
    `spark.mllib` modelPMML model
    spark.mllib modelPMML model
    This configuration limits the number of remote requests to fetch blocks at any given point. When the number of hosts in the cluster increase, it might lead to very large number - of in-bound connections to one or more nodes, causing the workers to fail under load. + of inbound connections to one or more nodes, causing the workers to fail under load. By allowing it to limit the number of fetch requests, this scenario can be mitigated.
    4194304 (4 MB) The estimated cost to open a file, measured by the number of bytes could be scanned at the same - time. This is used when putting multiple files into a partition. It is better to over estimate, + time. This is used when putting multiple files into a partition. It is better to overestimate, then the partitions with small files will be faster than partitions with bigger files.
    0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained + (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarse-grained mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, @@ -1634,7 +1634,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + failure happens. If external shuffle service is enabled, then the whole node will be blacklisted.
    spark.streaming.receiver.writeAheadLog.enable false - Enable write ahead logs for receivers. All the input data received through receivers - will be saved to write ahead logs that will allow it to be recovered after driver failures. + Enable write-ahead logs for receivers. All the input data received through receivers + will be saved to write-ahead logs that will allow it to be recovered after driver failures. See the deployment guide in the Spark Streaming programing guide for more details. spark.streaming.driver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the driver. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the metadata WAL on the driver. spark.streaming.receiver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the receivers. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the data WAL on the receivers. spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + Attribute-based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting applies only to executors. Refer to Mesos Attributes & Resources for more information on attributes.
      diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e07759a4dba87..ceda8a3ae2403 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -418,7 +418,7 @@ To use a custom metrics.properties for the application master and executors, upd - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example, you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. # Kerberos diff --git a/docs/security.md b/docs/security.md index 3e5607a9a0d67..8c0c66fb5a285 100644 --- a/docs/security.md +++ b/docs/security.md @@ -374,7 +374,7 @@ replaced with one of the above namespaces.
    ${ns}.enabledAlgorithms None - A comma separated list of ciphers. The specified ciphers must be supported by JVM. + A comma-separated list of ciphers. The specified ciphers must be supported by JVM.
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section of the Java security guide. The list for Java 8 can be found at diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 8fa643abf1373..f06e72a387df1 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -338,7 +338,7 @@ worker during one single schedule iteration. # Monitoring and Logging -Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. +Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. diff --git a/docs/sparkr.md b/docs/sparkr.md index 2909247e79e95..7fabab5d38f16 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -107,7 +107,7 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). ### From local data frames -The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically, we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R.
    {% highlight r %} @@ -169,7 +169,7 @@ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings {% endhighlight %}
    -The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example +The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example, we can save the SparkDataFrame from the previous example to a Parquet file using `write.df`.
    @@ -241,7 +241,7 @@ head(filter(df, df$waiting < 50)) ### Grouping, Aggregation -SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
    {% highlight r %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9822d669050d5..55d35b9dd31db 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -165,7 +165,7 @@ In addition to simple column references and expressions, Datasets also have a ri
    -In Python it's possible to access a DataFrame's columns either by attribute +In Python, it's possible to access a DataFrame's columns either by attribute (`df.age`) or by indexing (`df['age']`). While the former is convenient for interactive data exploration, users are highly encouraged to use the latter form, which is future proof and won't break with column names that @@ -278,7 +278,7 @@ the bytes back into an object. Spark SQL supports two different methods for converting existing RDDs into Datasets. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This -reflection based approach leads to more concise code and works well when you already know the schema +reflection-based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating Datasets is through a programmatic interface that allows you to @@ -1243,7 +1243,7 @@ The following options can be used to configure the version of Hive that is used
    com.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc

    - A comma separated list of class prefixes that should be loaded using the classloader that is + A comma-separated list of class prefixes that should be loaded using the classloader that is shared between Spark SQL and a specific version of Hive. An example of classes that should be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need to be shared are those that interact with classes that are already shared. For example, @@ -1441,7 +1441,7 @@ SELECT * FROM resultTable # Performance Tuning -For some workloads it is possible to improve performance by either caching data in memory, or by +For some workloads, it is possible to improve performance by either caching data in memory, or by turning on some experimental options. ## Caching Data In Memory @@ -1804,7 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. @@ -1966,11 +1966,11 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. ## Upgrading From Spark SQL 2.1 to 2.2 - - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). @@ -2013,7 +2013,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + - From Spark 1.6, by default, the Thrift server runs in multi-session mode. Which means each JDBC/ODBC connection owns a copy of their own SQL configuration and temporary function registry. Cached tables are still shared though. If you prefer to run the Thrift server in the old single-session mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add @@ -2161,7 +2161,7 @@ been renamed to `DataFrame`. This is primarily because DataFrames no longer inhe directly, but instead provide most of the functionality that RDDs provide though their own implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. -In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. @@ -2170,11 +2170,11 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users of either language should use `SQLContext` and `DataFrame`. In general these classes try to -use types that are usable from both languages (i.e. `Array` instead of language specific collections). +use types that are usable from both languages (i.e. `Array` instead of language-specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally, the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. @@ -2231,7 +2231,7 @@ referencing a singleton. ## Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. -Currently Hive SerDes and UDFs are based on Hive 1.2.1, +Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore (from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). @@ -2323,10 +2323,10 @@ A handful of Hive optimizations are not yet included in Spark. Some of these (su less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. -* Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you +* Block-level bitmap indexes and virtual columns (used to build indexes) +* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". -* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still +* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. * `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. @@ -2983,6 +2983,6 @@ does not exactly match standard floating point semantics. Specifically: - NaN = NaN returns true. - - In aggregations all NaN values are grouped together. + - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index 1dd54719b21aa..dacaa3438d489 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -39,7 +39,7 @@ For example, for Maven support, add the following to the pom.xml fi # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -The main category of parameters that should be configured are the authentication parameters +The main category of parameters that should be configured is the authentication parameters required by Keystone. The following table contains a list of Keystone mandatory parameters. PROVIDER can be diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 257a4f7d4f3ca..a1b6942ffe0a4 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -17,7 +17,7 @@ Choose a machine in your cluster such that - Flume can be configured to push data to a port on that machine. -Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data. #### Configuring Flume Configure Flume agent to send data to an Avro sink by having the following in the configuration file. @@ -100,7 +100,7 @@ Choose a machine that will run the custom sink in a Flume agent. The rest of the #### Configuring Flume Configuring Flume on the chosen machine requires the following two steps. -1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink. (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). @@ -128,7 +128,7 @@ Configuring Flume on the chosen machine requires the following two steps. agent.sinks.spark.port = agent.sinks.spark.channel = memoryChannel - Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about configuring Flume agents. diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md index 9f0671da2ee31..becf217738d26 100644 --- a/docs/streaming-kafka-0-8-integration.md +++ b/docs/streaming-kafka-0-8-integration.md @@ -10,7 +10,7 @@ Here we explain how to configure Spark Streaming to receive data from Kafka. The ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write-Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write-ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write-Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -55,11 +55,11 @@ Next, we discuss how to use this approach in your streaming application. **Points to remember:** - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write-Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). 3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. @@ -80,9 +80,9 @@ This approach has the following advantages over the receiver-based approach (i.e - *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write-Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write-Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write-Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high-level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with-write-ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ffda36d64a770..c30959263cdfa 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1461,7 +1461,7 @@ Note that the connections in the pool should be lazily created on demand and tim *** ## DataFrame and SQL Operations -You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.

    @@ -2010,10 +2010,10 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *Configuring write ahead logs* - Since Spark 1.2, - we have introduced _write ahead logs_ for achieving strong +- *Configuring write-ahead logs* - Since Spark 1.2, + we have introduced _write-ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into - a write ahead log in the configuration checkpoint directory. This prevents data loss on driver + a write-ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) @@ -2021,15 +2021,15 @@ To run a Spark Streaming applications, you need to have the following. come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the - received data within Spark be disabled when the write ahead log is enabled as the log is already + received data within Spark be disabled when the write-ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that - does not support flushing) for _write ahead logs_, please remember to enable + does not support flushing) for _write-ahead logs_, please remember to enable `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - Note that Spark will not encrypt data written to the write ahead log when I/O encryption is - enabled. If encryption of the write ahead log data is desired, it should be stored in a file + Note that Spark will not encrypt data written to the write-ahead log when I/O encryption is + enabled. If encryption of the write-ahead log data is desired, it should be stored in a file system that supports encryption natively. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming @@ -2284,9 +2284,9 @@ Having bigger blockinterval means bigger blocks. A high value of `spark.locality - Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in RDD randomly to create n number of partitions. Yes, for greater parallelism. Though comes at the cost of a shuffle. An RDD's processing is scheduled by driver's jobscheduler as a job. At a given point of time only one job is active. So, if one job is executing the other jobs are queued. -- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However the partitioning of the RDDs is not impacted. +- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted. -- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. +- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently, there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. *************************************************************************************************** @@ -2388,7 +2388,7 @@ then besides these losses, all of the past data that was received and replicated lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +ahead logs_ which save the received data to fault-tolerant storage. With the [write-ahead logs enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2402,7 +2402,7 @@ The following table summarizes the semantics under failures:
    Spark 1.1 or earlier, OR
    - Spark 1.2 or later without write ahead logs + Spark 1.2 or later without write-ahead logs
    Buffered data lost with unreliable receivers
    @@ -2416,7 +2416,7 @@ The following table summarizes the semantics under failures:
    Spark 1.2 or later with write ahead logsSpark 1.2 or later with write-ahead logs Zero data loss with reliable receivers
    At-least once semantics diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 5647ec6bc5797..71fd5b10cc407 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -15,7 +15,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also see the [Deploying](#deploying) subsection below. +For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also, see the [Deploying](#deploying) subsection below. ## Reading Data from Kafka diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 9a83f157452ad..602a4c70848e7 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write-Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. @@ -479,7 +479,7 @@ detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) -to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. +to track the read position in the stream. The engine uses checkpointing and write-ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` @@ -690,7 +690,7 @@ These examples generate streaming DataFrames that are untyped, meaning that the By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. -Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user-provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. @@ -2661,7 +2661,7 @@ sql("SET spark.sql.streaming.metricsEnabled=true") All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write-ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index a3643bf0838a1..77aa083c4a584 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,7 +177,7 @@ The master URL passed to Spark can be in one of the following formats: # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a -properties file and pass them on to your application. By default it will read options +properties file and pass them on to your application. By default, it will read options from `conf/spark-defaults.conf` in the Spark directory. For more detail, see the section on [loading default configurations](configuration.html#loading-default-configurations). diff --git a/docs/tuning.md b/docs/tuning.md index fc27713f28d46..912c39879be8f 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -196,7 +196,7 @@ To further tune garbage collection, we first need to understand some basic infor * A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old - enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked. + enough or Survivor2 is full, it is moved to Old. Finally, when Old is close to full, a full GC is invoked. The goal of GC tuning in Spark is to ensure that only long-lived RDDs are stored in the Old generation and that the Young generation is sufficiently sized to store short-lived objects. This will help avoid full GCs to collect diff --git a/python/README.md b/python/README.md index 3f17fdb98a081..2e0112da58b94 100644 --- a/python/README.md +++ b/python/README.md @@ -22,7 +22,7 @@ This packaging is currently experimental and may change in future versions (alth Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). **NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. diff --git a/sql/README.md b/sql/README.md index fe1d352050c09..70cc7c637b58d 100644 --- a/sql/README.md +++ b/sql/README.md @@ -6,7 +6,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`. From 94524019315ad463f9bc13c107131091d17c6af9 Mon Sep 17 00:00:00 2001 From: Yuchen Huo Date: Fri, 6 Apr 2018 08:35:20 -0700 Subject: [PATCH 264/593] [SPARK-23822][SQL] Improve error message for Parquet schema mismatches ## What changes were proposed in this pull request? This pull request tries to improve the error message for spark while reading parquet files with different schemas, e.g. One with a STRING column and the other with a INT column. A new ParquetSchemaColumnConvertNotSupportedException is added to replace the old UnsupportedOperationException. The Exception is again wrapped in FileScanRdd.scala to throw a more a general QueryExecutionException with the actual parquet file name which trigger the exception. ## How was this patch tested? Unit tests added to check the new exception and verify the error messages. Also manually tested with two parquet with different schema to check the error message. screen shot 2018-03-30 at 4 03 04 pm Author: Yuchen Huo Closes #20953 from yuchenhuo/SPARK-23822. --- ...emaColumnConvertNotSupportedException.java | 62 +++++++++++++++++++ .../parquet/VectorizedColumnReader.java | 38 ++++++++---- .../execution/QueryExecutionException.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 21 ++++++- .../parquet/ParquetSchemaSuite.scala | 55 ++++++++++++++++ 5 files changed, 166 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java new file mode 100644 index 0000000000000..82a1169cbe7ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * Exception thrown when the parquet reader find column type mismatches. + */ +@InterfaceStability.Unstable +public class SchemaColumnConvertNotSupportedException extends RuntimeException { + + /** + * Name of the column which cannot be converted. + */ + private String column; + /** + * Physical column type in the actual parquet file. + */ + private String physicalType; + /** + * Logical column type in the parquet schema the parquet reader use to parse all files. + */ + private String logicalType; + + public String getColumn() { + return column; + } + + public String getPhysicalType() { + return physicalType; + } + + public String getLogicalType() { + return logicalType; + } + + public SchemaColumnConvertNotSupportedException( + String column, + String physicalType, + String logicalType) { + super(); + this.column = column; + this.physicalType = physicalType; + this.logicalType = logicalType; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 47dd625f4b154..72f1d024b08ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.Arrays; import java.util.TimeZone; import org.apache.parquet.bytes.BytesUtils; @@ -31,6 +32,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -231,6 +233,18 @@ private boolean shouldConvertTimestamps() { return convertTz != null && !convertTz.equals(UTC); } + /** + * Helper function to construct exception for parquet schema mismatch. + */ + private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException( + ColumnDescriptor descriptor, + WritableColumnVector column) { + return new SchemaColumnConvertNotSupportedException( + Arrays.toString(descriptor.getPath()), + descriptor.getType().toString(), + column.dataType().toString()); + } + /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ @@ -261,7 +275,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -282,7 +296,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -321,7 +335,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; case BINARY: @@ -360,7 +374,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -375,7 +389,9 @@ private void decodeDictionaryIds( */ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { - assert(column.dataType() == DataTypes.BooleanType); + if (column.dataType() != DataTypes.BooleanType) { + throw constructConvertNotSupportedException(descriptor, column); + } defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } @@ -394,7 +410,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -414,7 +430,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -425,7 +441,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { defColumn.readFloats( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -436,7 +452,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { defColumn.readDoubles( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -471,7 +487,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -510,7 +526,7 @@ private void readFixedLenByteArrayBatch( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala index 16806c620635f..cffd97baea6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -17,4 +17,5 @@ package org.apache.spark.sql.execution -class QueryExecutionException(message: String) extends Exception(message) +class QueryExecutionException(message: String, cause: Throwable = null) + extends Exception(message, cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 835ce98462477..28c36b6020d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,11 +21,14 @@ import java.io.{FileNotFoundException, IOException} import scala.collection.mutable +import org.apache.parquet.io.ParquetDecodingException + import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator @@ -179,7 +182,23 @@ class FileScanRDD( currentIterator = readCurrentFile() } - hasNext + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } } else { currentFile = null InputFileBlockHolder.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 2cd2a600f2b97..9d3dfae348beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.parquet.io.ParquetDecodingException import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -382,6 +385,58 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + // ======================================= + // Tests for parquet schema mismatch error + // ======================================= + def testSchemaMismatch(path: String, vectorizedReaderEnabled: Boolean): SparkException = { + import testImplicits._ + + var e: SparkException = null + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReaderEnabled.toString) { + // Create two parquet files with different schemas in the same folder + Seq(("bcd", 2)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet(s"$path/parquet") + Seq((1, "abc")).toDF("a", "b").coalesce(1).write.mode("append").parquet(s"$path/parquet") + + e = intercept[SparkException] { + spark.read.parquet(s"$path/parquet").collect() + } + } + e + } + + test("schema mismatch failure error message for parquet reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) + val expectedMessage = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the corresponding " + + "files. Details:" + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[ParquetDecodingException]) + assert(e.getCause.getMessage.startsWith(expectedMessage)) + } + } + + test("schema mismatch failure error message for parquet vectorized reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true) + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + + // Check if the physical type is reporting correctly + val errMsg = e.getCause.getMessage + assert(errMsg.startsWith("Parquet column cannot be converted in file")) + val file = errMsg.substring("Parquet column cannot be converted in file ".length, + errMsg.indexOf(". ")) + val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) + assert(col.length == 1) + if (col(0).dataType == StringType) { + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + } else { + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + } + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From d766ea2ff2bf59afbd631d3cc2e43bebfccdebed Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 7 Apr 2018 00:15:54 +0800 Subject: [PATCH 265/593] [SPARK-23861][SQL][DOC] Clarify default window frame with and without orderBy clause ## What changes were proposed in this pull request? Add docstring to clarify default window frame boundaries with and without orderBy clause ## How was this patch tested? Manually generate doc and check. Author: Li Jin Closes #20978 from icexelloss/SPARK-23861-window-doc. --- python/pyspark/sql/window.py | 4 ++++ .../main/scala/org/apache/spark/sql/expressions/Window.scala | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index e667fba099fb9..d19ced954f04e 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -44,6 +44,10 @@ class Window(object): >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + .. note:: When ordering is not defined, an unbounded window frame (rowFrame, + unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, + a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default. + .. note:: Experimental .. versionadded:: 1.4 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 1caa243f8d118..cd819bab1b14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.catalyst.expressions._ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) * }}} * + * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, + * unboundedFollowing) is used by default. When ordering is defined, a growing window frame + * (rangeFrame, unboundedPreceding, currentRow) is used by default. + * * @since 1.4.0 */ @InterfaceStability.Stable From c926acf719a6deb9d884a0f19bde075c312bfe5a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 18:42:14 +0200 Subject: [PATCH 266/593] [SPARK-23882][CORE] UTF8StringSuite.writeToOutputStreamUnderflow() is not expected to be supported ## What changes were proposed in this pull request? This PR excludes an existing UT [`writeToOutputStreamUnderflow()`](https://github.com/apache/spark/blob/master/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java#L519-L532) in `UTF8StringSuite`. As discussed [here](https://github.com/apache/spark/pull/19222#discussion_r177692142), the behavior of this test looks surprising. This test seems to access metadata area of the JVM object where is reserved by `Platform.BYTE_ARRAY_OFFSET`. This test is introduced thru #16089 by NathanHowell. More specifically, [the commit](https://github.com/apache/spark/pull/16089/commits/27c102deb1701fe62f776fe4da61dac959270b73) `Improve test coverage of UTFString.write` introduced this UT. However, I cannot find any discussion about this UT. I think that it would be good to exclude this UT. ```java public void writeToOutputStreamUnderflow() throws IOException { // offset underflow is apparently supported? final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { new UTF8String( new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); outputStream.reset(); } } ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20995 from kiszk/SPARK-23882. --- .../spark/unsafe/types/UTF8StringSuite.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bad908fcaf136..652c40a35527f 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -515,22 +515,6 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } - @Test - public void writeToOutputStreamUnderflow() throws IOException { - // offset underflow is apparently supported? - final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); - - for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - new UTF8String( - new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) - .writeTo(outputStream); - final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); - assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); - outputStream.reset(); - } - } - @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); From d23a805f975f209f273db2b52de3f336be17d873 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Fri, 6 Apr 2018 10:09:55 -0700 Subject: [PATCH 267/593] [SPARK-23859][ML] Initial PR for Instrumentation improvements: UUID and logging levels ## What changes were proposed in this pull request? Initial PR for Instrumentation improvements: UUID and logging levels. This PR takes over #20837 Closes #20837 ## How was this patch tested? Manual. Author: Bago Amirbekian Author: WeichenXu Closes #20982 from WeichenXu123/better-instrumentation. --- .../classification/LogisticRegression.scala | 15 ++++--- .../spark/ml/util/Instrumentation.scala | 40 +++++++++++++++---- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 3ae4db3f3f965..ee4b01058c75c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } val isConstantLabel = histogram.count(_ != 0.0) == 1 if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { - logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + - s"will be zeros. Training is not needed.") + instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " + + s"coefficients will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") ( (coefMatrix, interceptVec, Array.empty[Double]) } else { if (!$(fitIntercept) && isConstantLabel) { - logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + + instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + s"dangerous ground, so the algorithm may not converge.") } @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") ( if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant " + "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") ( (_initialModel.interceptVector.size == numCoefficientSets) && (_initialModel.getFitIntercept == $(fitIntercept)) if (!modelIsValid) { - logWarning(s"Initial coefficients will be ignored! Its dimensions " + + instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " + s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " + s"expected size ($numCoefficientSets, $numFeatures)") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 7c46f45c59717..e694bc27b2f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import java.util.concurrent.atomic.AtomicLong +import java.util.UUID import org.json4s._ import org.json4s.JsonDSL._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset private[spark] class Instrumentation[E <: Estimator[_]] private ( estimator: E, dataset: RDD[_]) extends Logging { - private val id = Instrumentation.counter.incrementAndGet() + private val id = UUID.randomUUID() private val prefix = { val className = estimator.getClass.getSimpleName s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " @@ -56,12 +56,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } /** - * Logs a message with a prefix that uniquely identifies the training session. + * Logs a warning message with a prefix that uniquely identifies the training session. */ - def log(msg: String): Unit = { - logInfo(prefix + msg) + override def logWarning(msg: => String): Unit = { + super.logWarning(prefix + msg) } + /** + * Logs a error message with a prefix that uniquely identifies the training session. + */ + override def logError(msg: => String): Unit = { + super.logError(prefix + msg) + } + + /** + * Logs an info message with a prefix that uniquely identifies the training session. + */ + override def logInfo(msg: => String): Unit = { + super.logInfo(prefix + msg) + } + + /** + * Alias for logInfo, see above. + */ + def log(msg: String): Unit = logInfo(msg) + /** * Logs the value of the given parameters for the estimator being used in this session. */ @@ -77,11 +96,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } def logNumFeatures(num: Long): Unit = { - log(compact(render("numFeatures" -> num))) + logNamedValue(Instrumentation.loggerTags.numFeatures, num) } def logNumClasses(num: Long): Unit = { - log(compact(render("numClasses" -> num))) + logNamedValue(Instrumentation.loggerTags.numClasses, num) } /** @@ -107,7 +126,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Some common methods for logging information about a training session. */ private[spark] object Instrumentation { - private val counter = new AtomicLong(0) + + object loggerTags { + val numFeatures = "numFeatures" + val numClasses = "numClasses" + val numExamples = "numExamples" + } /** * Creates an instrumentation object for a training session. From b6935ffb4dfb1d9fdf36ba402ac07bd02978c012 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:23:26 -0700 Subject: [PATCH 268/593] [SPARK-10399][SPARK-23879][HOTFIX] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following errors in [Java lint](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-lint/7717/console) after #19222 has been merged. These errors were pointed by ueshin . ``` [ERROR] src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java:[57] (sizes) LineLength: Line is longer than 100 characters (found 106). [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[26,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java:[23,10] (modifier) ModifierOrder: 'public' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[64,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[69,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[74,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[79,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[84,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[89,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[94,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[99,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[104,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[109,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[114,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[119,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[124,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[129,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[60,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[65,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[70,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[75,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[80,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[85,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[90,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[95,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[100,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[105,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[110,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[115,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[120,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[125,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java:[114,16] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[30,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.memory.MemoryBlock. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[126,15] (naming) MethodName: Method name 'ByteArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[143,15] (naming) MethodName: Method name 'OnHeapMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[160,15] (naming) MethodName: Method name 'OffHeapArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[19,8] (imports) UnusedImports: Unused import - com.google.common.primitives.Ints. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20991 from kiszk/SPARK-10399-jlint. --- .../sql/catalyst/expressions/HiveHasher.java | 1 - .../spark/unsafe/array/ByteArrayMethods.java | 4 +-- .../unsafe/memory/ByteArrayMemoryBlock.java | 28 +++++++++---------- .../unsafe/memory/HeapMemoryAllocator.java | 2 -- .../spark/unsafe/memory/MemoryBlock.java | 2 +- .../unsafe/memory/OffHeapMemoryBlock.java | 2 +- .../unsafe/memory/OnHeapMemoryBlock.java | 28 +++++++++---------- .../spark/unsafe/memory/MemoryBlockSuite.java | 6 ++-- .../spark/unsafe/types/UTF8StringSuite.java | 1 - .../spark/sql/catalyst/expressions/XXH64.java | 3 -- .../catalyst/expressions/HiveHasherSuite.java | 1 - 11 files changed, 35 insertions(+), 43 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 5d905943a3aa7..c34e36903a93e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index c334c9651cf6b..4bc9955090fd7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -54,7 +54,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEqualsBlock( - MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); } @@ -64,7 +64,7 @@ public static boolean arrayEqualsBlock( * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { int i = 0; // check if starts align and we can get both offsets to be aligned diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java index 99a9868a49a79..9f238632bc87a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -57,72 +57,72 @@ public static ByteArrayMemoryBlock fromArray(final byte[] array) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index acf28fd7ee59b..36caf80888cda 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,8 +23,6 @@ import java.util.LinkedList; import java.util.Map; -import org.apache.spark.unsafe.Platform; - /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index b086941108522..ca7213bbf92da 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -111,7 +111,7 @@ public final void fill(byte value) { /** * Instantiate MemoryBlock for given object type with new offset */ - public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { MemoryBlock mb = null; if (obj instanceof byte[]) { byte[] array = (byte[])obj; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java index f90f62bf21dcb..3431b08980eb8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; public class OffHeapMemoryBlock extends MemoryBlock { - static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); public OffHeapMemoryBlock(long address, long size) { super(null, address, size); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java index 12f67c7bd593e..ee42bc27c9c5f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -61,72 +61,72 @@ public static OnHeapMemoryBlock fromArray(final long[] array, long size) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return Platform.getByte(array, this.offset + offset); } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { Platform.putByte(array, this.offset + offset, value); } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 47f05c928f2e5..5d5fdc1c55a75 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -123,7 +123,7 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } @Test - public void ByteArrayMemoryBlockTest() { + public void testByteArrayMemoryBlock() { byte[] obj = new byte[56]; long offset = Platform.BYTE_ARRAY_OFFSET; int length = obj.length; @@ -140,7 +140,7 @@ public void ByteArrayMemoryBlockTest() { } @Test - public void OnHeapMemoryBlockTest() { + public void testOnHeapMemoryBlock() { long[] obj = new long[7]; long offset = Platform.LONG_ARRAY_OFFSET; int length = obj.length * 8; @@ -157,7 +157,7 @@ public void OnHeapMemoryBlockTest() { } @Test - public void OffHeapArrayMemoryBlockTest() { + public void testOffHeapArrayMemoryBlock() { MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); MemoryBlock memory = memoryAllocator.allocate(56); Object obj = memory.getBaseObject(); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 652c40a35527f..2c08535a16465 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -27,7 +27,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index 883748932ad33..fe727f6011cbf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,9 +16,6 @@ */ package org.apache.spark.sql.catalyst.expressions; -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 8ffc1d7c24d61..76930f9368514 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; From e998250588de0df250e2800278da4d3e3705c259 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 6 Apr 2018 11:51:36 -0700 Subject: [PATCH 269/593] [SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels ## What changes were proposed in this pull request? The Scala StringIndexerModel has an alternate constructor that will create the model from an array of label strings. Add the corresponding Python API: model = StringIndexerModel.from_labels(["a", "b", "c"]) ## How was this patch tested? Add doctest and unit test. Author: Huaxin Gao Closes #20968 from huaxingao/spark-23828. --- python/pyspark/ml/feature.py | 88 ++++++++++++++++++++++++++---------- python/pyspark/ml/tests.py | 41 ++++++++++++++++- 2 files changed, 104 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fcb0dfc563720..5a3e0dd655150 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2342,9 +2342,38 @@ def mean(self): return self._call_java("mean") +class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`. + """ + + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "or NULL values) in features and label column of string type. " + + "Options are 'skip' (filter out rows with invalid data), " + + "error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + + def __init__(self, *args): + super(_StringIndexerParams, self).__init__(*args) + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringOrderType) + + @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] + >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"], + ... inputCol="label", outputCol="indexed", handleInvalid="error") + >>> result = fromlabelsModel.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)] .. versionadded:: 1.4.0 """ - stringOrderType = Param(Params._dummy(), "stringOrderType", - "How to order labels of string column. The first label after " + - "ordering is assigned an index of 0. Supported options: " + - "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", - typeConverter=TypeConverters.toString) - - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + - "or NULL values) in features and label column of string type. " + - "Options are 'skip' (filter out rows with invalid data), " + - "error (throw an error), or 'keep' (put invalid data " + - "in a special additional bucket, at index numLabels).", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2414,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2440,21 +2461,33 @@ def setStringOrderType(self, value): """ return self._set(stringOrderType=value) - @since("2.3.0") - def getStringOrderType(self): - """ - Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. - """ - return self.getOrDefault(self.stringOrderType) - -class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): +class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`StringIndexer`. .. versionadded:: 1.4.0 """ + @classmethod + @since("2.4.0") + def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None): + """ + Construct the model directly from an array of label strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jlabels = StringIndexerModel._new_java_array(labels, java_class) + model = StringIndexerModel._create_from_java_class( + "org.apache.spark.ml.feature.StringIndexerModel", jlabels) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if handleInvalid is not None: + model.setHandleInvalid(handleInvalid) + return model + @property @since("1.5.0") def labels(self): @@ -2463,6 +2496,13 @@ def labels(self): """ return self._call_java("labels") + @since("2.4.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c2c4861e2aff4..4ce54547eab09 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -800,6 +800,43 @@ def test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + class HasInducedError(Params): @@ -2097,9 +2134,11 @@ def test_java_params(self): ParamTests.check_params(self, cls(), check_params_exist=False) # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), check_params_exist=False) + ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) def _squared_distance(a, b): From 6ab134ca7d8f7802a6d196929513cc02b9b4d35d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 6 Apr 2018 15:00:13 -0700 Subject: [PATCH 270/593] [SPARK-21898][ML][FOLLOWUP] Fix Scala 2.12 build. ## What changes were proposed in this pull request? This is a follow-up pr of #19108 which broke Scala 2.12 build. ``` [error] /Users/ueshin/workspace/apache-spark/spark/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala:86: overloaded method value test with alternatives: [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: org.apache.spark.api.java.function.Function[java.lang.Double,java.lang.Double])org.apache.spark.sql.DataFrame [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: scala.Double => scala.Double)org.apache.spark.sql.DataFrame [error] cannot be applied to (org.apache.spark.sql.DataFrame, String, scala.Double => java.lang.Double) [error] test(dataset, sampleCol, (x: Double) => cdf.call(x)) [error] ^ [error] one error found ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20994 from ueshin/issues/SPARK-21898/fix_scala-2.12. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index 8d80e7768cb6e..c62d7463288f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -83,7 +83,8 @@ object KolmogorovSmirnovTest { @Since("2.4.0") def test(dataset: DataFrame, sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + val f: Double => Double = x => cdf.call(x) + test(dataset, sampleCol, f) } /** From 2c1fe647575e97e28b2232478ca86847d113e185 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 8 Apr 2018 12:09:06 +0800 Subject: [PATCH 271/593] [SPARK-23847][PYTHON][SQL] Add asc_nulls_first, asc_nulls_last to PySpark ## What changes were proposed in this pull request? Column.scala and Functions.scala have asc_nulls_first, asc_nulls_last, desc_nulls_first and desc_nulls_last. Add the corresponding python APIs in column.py and functions.py ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #20962 from huaxingao/spark-23847. --- python/pyspark/sql/column.py | 56 +++++++++++++++++-- python/pyspark/sql/functions.py | 13 +++++ python/pyspark/sql/tests.py | 14 +++++ .../scala/org/apache/spark/sql/Column.scala | 4 +- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 82 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 922c7cf288f8f..e7dec11c69b57 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -447,24 +447,72 @@ def isin(self, *cols): # order _asc_doc = """ - Returns a sort expression based on the ascending order of the given column name + Returns a sort expression based on ascending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc()).collect() [Row(name=u'Alice'), Row(name=u'Tom')] """ + _asc_nulls_first_doc = """ + Returns a sort expression based on ascending order of the column, and null values + return before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + + .. versionadded:: 2.4 + """ + _asc_nulls_last_doc = """ + Returns a sort expression based on ascending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + + .. versionadded:: 2.4 + """ _desc_doc = """ - Returns a sort expression based on the descending order of the given column name. + Returns a sort expression based on the descending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc()).collect() [Row(name=u'Tom'), Row(name=u'Alice')] """ + _desc_nulls_first_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + + .. versionadded:: 2.4 + """ + _desc_nulls_last_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + + .. versionadded:: 2.4 + """ asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc)) + asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc)) desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) + desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc)) + desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc)) _isNull_doc = """ True if the current expression is null. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad3e37c872628..1b192680f0795 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -138,6 +138,17 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } +_functions_2_4 = { + 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values return before non-null values.', + 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values appear after non-null values.', + 'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear before non-null values.', + 'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear after non-null values', +} + _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. @@ -250,6 +261,8 @@ def _(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) +for _name, _doc in _functions_2_4.items(): + globals()[_name] = since(2.4)(_create_function(_name, _doc)) del _name, _doc diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5181053a0d318..dd04ffb4ed393 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,6 +2991,20 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() + def test_2_4_functions(self): + from pyspark.sql import functions + + df = self.spark.createDataFrame( + [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 92988680871a4..ad0efbae89830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1083,10 +1083,10 @@ class Column(val expr: Expression) extends Logging { * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. - * df.sort(df("age").asc_nulls_last) + * df.sort(df("age").asc_nulls_first) * * // Java - * df.sort(df.col("age").asc_nulls_last()); + * df.sort(df.col("age").asc_nulls_first()); * }}} * * @group expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9ca9a8996344..c658f25ced053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -132,7 +132,7 @@ object functions { * Returns a sort expression based on ascending order of the column, * and null values return before non-null values. * {{{ - * df.sort(asc_nulls_last("dept"), desc("age")) + * df.sort(asc_nulls_first("dept"), desc("age")) * }}} * * @group sort_funcs From 6a734575a80e6b4ec4963206254451f05d64b742 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 7 Apr 2018 21:44:32 -0700 Subject: [PATCH 272/593] [SPARK-23849][SQL] Tests for the samplingRatio option of JSON datasource ## What changes were proposed in this pull request? Proposed tests checks that only subset of input dataset is touched during schema inferring. Author: Maxim Gekk Closes #20963 from MaxGekk/json-sampling-tests. --- .../datasources/json/JsonSuite.scala | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 10bac0554484a..70aee561ff0f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -2127,4 +2127,39 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(df.schema === expectedSchema) } } + + test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + withTempPath { path => + val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), + StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) + for (i <- 0 until 100) { + if (predefinedSample.contains(i)) { + writer.write(s"""{"f1":${i.toString}}""" + "\n") + } else { + writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") + } + } + writer.close() + + val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + assert(ds.schema == new StructType().add("f1", LongType)) + } + } + + test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { + val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(i)) { + s"""{"f1":${i.toString}}""" + "\n" + } else { + s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" + } + }.toDS() + val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + + assert(ds.schema == new StructType().add("f1", LongType)) + } } From 710a68cec27a94c2df10d8b4022a755a94a5443b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:26:31 +0200 Subject: [PATCH 273/593] [SPARK-23892][TEST] Improve converge and fix lint error in UTF8String-related tests ## What changes were proposed in this pull request? This PR improves test coverage in `UTF8StringSuite` and code efficiency in `UTF8StringPropertyCheckSuite`. This PR also fixes lint-java issue in `UTF8StringSuite` reported at [here](https://github.com/apache/spark/pull/20995#issuecomment-379325527) ```[ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[28,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform.``` ## How was this patch tested? Existing UT Author: Kazuaki Ishizaki Closes #21000 from kiszk/SPARK-23892. --- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 5 ++--- .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 2c08535a16465..42dda30480702 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,7 +25,6 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; @@ -53,8 +52,8 @@ private static void checkBasic(String str, int len) { assertTrue(s1.contains(s2)); assertTrue(s2.contains(s1)); - assertTrue(s1.startsWith(s1)); - assertTrue(s1.endsWith(s1)); + assertTrue(s1.startsWith(s2)); + assertTrue(s1.endsWith(s2)); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 62d4176d00f94..48004e812a8bf 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -164,7 +164,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { if (length <= 0) return "" if (length <= origin.length) { - if (length <= 0) "" else origin.substring(0, length) + origin.substring(0, length) } else { if (pad.length == 0) return origin val toPad = length - origin.length From 8d40a79a077a30024a8ef921781b68f6f7e542d1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:40:27 +0200 Subject: [PATCH 274/593] [SPARK-23893][CORE][SQL] Avoid possible integer overflow in multiplication ## What changes were proposed in this pull request? This PR avoids possible overflow at an operation `long = (long)(int * int)`. The multiplication of large positive integer values may set one to MSB. This leads to a negative value in long while we expected a positive value (e.g. `0111_0000_0000_0000 * 0000_0000_0000_0010`). This PR performs long cast before the multiplication to avoid this situation. ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21002 from kiszk/SPARK-23893. --- .../util/collection/unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../util/collection/unsafe/sort/UnsafeSortDataFormat.java | 2 +- .../src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../scala/org/apache/spark/InternalAccumulatorSuite.scala | 2 +- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- .../test/scala/org/apache/spark/util/JsonProtocolSuite.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/HashBenchmark.scala | 2 +- .../scala/org/apache/spark/sql/HashByteArrayBenchmark.scala | 3 ++- .../org/apache/spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../columnar/compression/CompressionSchemeBenchmark.scala | 4 ++-- .../sql/execution/vectorized/ColumnarBatchBenchmark.scala | 2 +- 11 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 20a7a8b267438..717823ebbd320 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -124,7 +124,7 @@ public UnsafeInMemorySorter( int initialSize, boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2), canUseRadixSort); + consumer.allocateArray(initialSize * 2L), canUseRadixSort); } public UnsafeInMemorySorter( diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d9f84d10e9051..37772f41caa87 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -84,7 +84,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int @Override public LongArray allocate(int length) { - assert (length * 2 <= buffer.size()) : + assert (length * 2L <= buffer.size()) : "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2); return buffer; } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index c9ed12f4e1bd4..13db4985b0b80 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -90,7 +90,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // Otherwise, interpolate the number of partitions we need to try, but overestimate it // by 50%. We also cap the estimation in the end. if (results.size == 0) { - numPartsToTry = partsScanned * 4 + numPartsToTry = partsScanned * 4L } else { // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 8d7be77f51fe9..62824a5bec9d1 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt // ID that's < 2 * numPartitions belongs to the first attempt of this stage. val taskContext = TaskContext.get() - val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L if (isFirstStageAttempt) { throw new FetchFailedException( SparkEnv.get.blockManager.blockManagerId, diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fde5f25bce456..0ba57bf4563c1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -382,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) writeFile(log, true, None, SparkListenerApplicationStart( - "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), - SparkListenerApplicationEnd(5001 * i) + "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001L * i) ) log } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 4abbb8e7894f5..74b72d940eeef 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("SparkListenerJobStart backward compatibility") { // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) @@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite { // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. // Also, SparkListenerJobEnd did not have a "Completion Time" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) .removeField({ _._1 == "Submission Time"}) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 2d94b66a1e122..9a89e6290e695 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -40,7 +40,7 @@ object HashBenchmark { safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() ).toArray - val benchmark = new Benchmark("Hash For " + name, iters * numRows) + val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong) benchmark.addCase("interpreted version") { _: Int => var sum = 0 for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 2a753a0c84ed5..f6c8111f5bc57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -36,7 +36,8 @@ object HashByteArrayBenchmark { bytes } - val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + val benchmark = + new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong) benchmark.addCase("Murmur3_x86_32") { _: Int => var sum = 0L for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 769addf3b29e6..6c63769945312 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark { val iters = 1024 * 16 val numRows = 1024 * 16 - val benchmark = new Benchmark("unsafe projection", iters * numRows) + val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong) val schema1 = new StructType().add("l", LongType, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 9005ec93e786e..619b76fabdd5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -77,7 +77,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -101,7 +101,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 1f31aa45a1220..8aeb06d428951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -295,7 +295,7 @@ object ColumnarBatchBenchmark { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 - val benchmark = new Benchmark("Boolean Read/Write", iters * count) + val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong) benchmark.addCase("Bitset") { i: Int => { val b = new BitSet(count) var sum = 0L From 32471ba0af52b59141b44a8375025b6a7eafae70 Mon Sep 17 00:00:00 2001 From: Nolan Emirot Date: Mon, 9 Apr 2018 08:04:02 -0500 Subject: [PATCH 275/593] Fix typo in Python docstring kinesis example ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nolan Emirot Closes #20990 from emirot/kinesis_stream_example_typo. --- .../src/main/python/examples/streaming/kinesis_wordcount_asl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index 4d7fc9a549bfb..49794faab88c4 100644 --- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -34,7 +34,7 @@ $ export AWS_SECRET_KEY= # run the example - $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\ + $ bin/spark-submit -jars external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com From d81f29ecafe8fc9816e36087e3b8acdc93d6cc1b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 9 Apr 2018 10:19:22 -0700 Subject: [PATCH 276/593] [SPARK-23881][CORE][TEST] Fix flaky test JobCancellationSuite."interruptible iterator of shuffle reader" ## What changes were proposed in this pull request? The test case JobCancellationSuite."interruptible iterator of shuffle reader" has been flaky because `KillTask` event is handled asynchronously, so it can happen that the semaphore is released but the task is still running. Actually we only have to check if the total number of processed elements is less than the input elements number, so we know the task get cancelled. ## How was this patch tested? The new test case still fails without the purposed patch, and succeeded in current master. Author: Xingbo Jiang Closes #20993 from jiangxb1987/JobCancellationSuite. --- .../apache/spark/JobCancellationSuite.scala | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 3b793bb231cf3..61da4138896cd 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -332,13 +332,15 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft import JobCancellationSuite._ sc = new SparkContext("local[2]", "test interruptible iterator") + // Increase the number of elements to be proceeded to avoid this test being flaky. + val numElements = 10000 val taskCompletedSem = new Semaphore(0) sc.addSparkListener(new SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { // release taskCancelledSemaphore when cancelTasks event has been posted if (stageCompleted.stageInfo.stageId == 1) { - taskCancelledSemaphore.release(1000) + taskCancelledSemaphore.release(numElements) } } @@ -349,28 +351,31 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) - val f = sc.parallelize(1 to 1000).map { i => (i, i) } + // Explicitly disable interrupt task thread on cancelling tasks, so the task thread can only be + // interrupted by `InterruptibleIterator`. + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + + val f = sc.parallelize(1 to numElements).map { i => (i, i) } .repartitionAndSortWithinPartitions(new HashPartitioner(1)) .mapPartitions { iter => taskStartedSemaphore.release() iter }.foreachAsync { x => - if (x._1 >= 10) { - // This block of code is partially executed. It will be blocked when x._1 >= 10 and the - // next iteration will be cancelled if the source iterator is interruptible. Then in this - // case, the maximum num of increment would be 10(|1...10|) - taskCancelledSemaphore.acquire() - } + // Block this code from being executed, until the job get cancelled. In this case, if the + // source iterator is interruptible, the max number of increment should be under + // `numElements`. + taskCancelledSemaphore.acquire() executionOfInterruptibleCounter.getAndIncrement() } taskStartedSemaphore.acquire() // Job is cancelled when: // 1. task in reduce stage has been started, guaranteed by previous line. - // 2. task in reduce stage is blocked after processing at most 10 records as - // taskCancelledSemaphore is not released until cancelTasks event is posted - // After job being cancelled, task in reduce stage will be cancelled and no more iteration are - // executed. + // 2. task in reduce stage is blocked as taskCancelledSemaphore is not released until + // JobCancelled event is posted. + // After job being cancelled, task in reduce stage will be cancelled asynchronously, thus + // partial of the inputs should not get processed (It's very unlikely that Spark can process + // 10000 elements between JobCancelled is posted and task is really killed). f.cancel() val e = intercept[SparkException](f.get()).getCause @@ -378,7 +383,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Make sure tasks are indeed completed. taskCompletedSem.acquire() - assert(executionOfInterruptibleCounter.get() <= 10) + assert(executionOfInterruptibleCounter.get() < numElements) } def testCount() { From 10f45bb8233e6ac838dd4f053052c8556f5b54bd Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 9 Apr 2018 11:31:21 -0700 Subject: [PATCH 277/593] [SPARK-23816][CORE] Killed tasks should ignore FetchFailures. SPARK-19276 ensured that FetchFailures do not get swallowed by other layers of exception handling, but it also meant that a killed task could look like a fetch failure. This is particularly a problem with speculative execution, where we expect to kill tasks as they are reading shuffle data. The fix is to ensure that we always check for killed tasks first. Added a new unit test which fails before the fix, ran it 1k times to check for flakiness. Full suite of tests on jenkins. Author: Imran Rashid Closes #20987 from squito/SPARK-23816. --- .../org/apache/spark/executor/Executor.scala | 26 +++--- .../apache/spark/executor/ExecutorSuite.scala | 92 +++++++++++++++---- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index dcec3ec21b546..c325222b764b8 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -480,6 +480,19 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { @@ -494,19 +507,6 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - - case _: InterruptedException | NonFatal(_) if - task != null && task.reasonIfKilled.isDefined => - val killReason = task.reasonIfKilled.getOrElse("unknown reason") - logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) - case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 105a178f2d94e..1a7bebe2c53cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.Map import scala.concurrent.duration._ @@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // the fetch failure. The executor should still tell the driver that the task failed due to a // fetch failure, not a generic exception from user code. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true) + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError]) + } + + test("SPARK-23816: interrupts are not masked by a FetchFailure") { + // If killing the task causes a fetch failure, we still treat it as a task that was killed, + // as the fetch failure could easily be caused by interrupting the thread. + val (failReason, _) = testFetchFailureHandling(false) + assert(failReason.isInstanceOf[TaskKilled]) + } + + /** + * Helper for testing some cases where a FetchFailure should *not* get sent back, because its + * superceded by another error, either an OOM or intentionally killing a task. + * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the + * FetchFailure + */ + private def testFetchFailureHandling( + oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. + // SPARK-23816 also handle interrupts the same way, as killing an obsolete speculative task + // does not represent a real fetch failure. val conf = new SparkConf().setMaster("local").setAppName("executor suite test") sc = new SparkContext(conf) val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat - // the fetch failure as a false positive, and just do normal OOM handling. + // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We + // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + if (!oom) { + // we are trying to setup a case where a task is killed after a fetch failure -- this + // is just a helper to coordinate between the task thread and this thread that will + // kill the task + ExecutorSuiteHelper.latches = new ExecutorSuiteHelper() + } + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serTask = serializer.serialize(task) val taskDescription = createFakeTaskDescription(serTask) - val (failReason, uncaughtExceptionHandler) = - runTaskGetFailReasonAndExceptionHandler(taskDescription) - // make sure the task failure just looks like a OOM, not a fetch failure - assert(failReason.isInstanceOf[ExceptionFailure]) - val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) - verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) - assert(exceptionCaptor.getAllValues.size === 1) - assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) - } + runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom) + } test("Gracefully handle error in task deserialization") { val conf = new SparkConf @@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { - runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1 } private def runTaskGetFailReasonAndExceptionHandler( - taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + taskDescription: TaskDescription, + killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null + val timedOut = new AtomicBoolean(false) try { executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) + if (killTask) { + val killingThread = new Thread("kill-task") { + override def run(): Unit = { + // wait to kill the task until it has thrown a fetch failure + if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) { + // now we can kill the task + executor.killAllTasks(true, "Killed task, eg. because of speculative execution") + } else { + timedOut.set(true) + } + } + } + killingThread.start() + } eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } + assert(!timedOut.get(), "timed out waiting to be ready to kill tasks") } finally { if (executor != null) { executor.stop() @@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) orderedMock.verify(mockBackend) .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture()) // first statusUpdate for RUNNING has empty data assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting @@ -321,7 +364,8 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, val input: FetchFailureThrowingRDD, - throwOOM: Boolean) extends RDD[Int](input) { + throwOOM: Boolean, + interrupt: Boolean) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, context) try { @@ -330,6 +374,15 @@ class FetchFailureHidingRDD( case t: Throwable => if (throwOOM) { throw new OutOfMemoryError("OOM while handling another exception") + } else if (interrupt) { + // make sure our test is setup correctly + assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) + // signal our test is ready for the task to get killed + ExecutorSuiteHelper.latches.latch1.countDown() + // then wait for another thread in the test to kill the task -- this latch + // is never actually decremented, we just wait to get killed. + ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS) + throw new IllegalStateException("timed out waiting to be interrupted") } else { throw new RuntimeException("User Exception that hides the original exception", t) } @@ -352,6 +405,11 @@ private class ExecutorSuiteHelper { @volatile var testFailedReason: TaskFailedReason = _ } +// helper for coordinating killing tasks +private object ExecutorSuiteHelper { + var latches: ExecutorSuiteHelper = null +} + private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { def writeExternal(out: ObjectOutput): Unit = {} def readExternal(in: ObjectInput): Unit = { From 7c1654e2159662e7e663ba141719d755002f770a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Apr 2018 11:54:35 -0700 Subject: [PATCH 278/593] [SPARK-22856][SQL] Add wrappers for codegen output and nullability ## What changes were proposed in this pull request? The codegen output of `Expression`, aka `ExprCode`, now encapsulates only strings of output value (`value`) and nullability (`isNull`). It makes difficulty for us to know what the output really is. I think it is better if we can add wrappers for the value and nullability that let us to easily know that. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20043 from viirya/SPARK-22856. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 16 ++-- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 16 ++-- .../expressions/codegen/CodegenFallback.scala | 2 +- .../expressions/codegen/ExprValue.scala | 76 +++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 19 +++-- .../codegen/GenerateUnsafeProjection.scala | 8 +- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeCreator.scala | 8 +- .../expressions/conditionalExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 4 +- .../spark/sql/catalyst/expressions/hash.scala | 4 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 24 +++--- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 25 ++++-- .../expressions/objects/objects.scala | 28 ++++--- .../sql/catalyst/expressions/predicates.scala | 10 +-- .../expressions/randomExpressions.scala | 6 +- .../expressions/CodeGenerationSuite.scala | 6 +- .../expressions/codegen/ExprValueSuite.scala | 46 +++++++++++ .../sql/execution/ColumnarBatchScan.scala | 10 ++- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 13 ++-- .../sql/execution/WholeStageCodegenExec.scala | 13 ++-- .../aggregate/HashAggregateExec.scala | 3 +- .../aggregate/HashMapGenerator.scala | 5 +- .../execution/basicPhysicalOperators.scala | 6 +- .../joins/BroadcastHashJoinExec.scala | 6 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- 35 files changed, 294 insertions(+), 120 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 89ffbb0016916..5021a567592e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ /** @@ -76,7 +76,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 38caf67d465d8..7a5e49cb5206b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", isNull, value)) + val eval = doGenCode(ctx, ExprCode("", + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -118,10 +120,10 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = globalIsNull + eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) s"$globalIsNull = $localIsNull;" } else { "" @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = newValue + eval.value = VariableValue(newValue, javaType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -446,7 +448,7 @@ abstract class UnaryExpression extends Expression { boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -546,7 +548,7 @@ abstract class BinaryExpression extends Expression { ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -690,7 +692,7 @@ abstract class TernaryExpression extends Expression { ${midGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index dd523d312e3b4..ad1e7bdb31987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + $countTerm++;""", isNull = FalseLiteral) } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index cc6a769d032d3..787bcaf5e81de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", - isNull = "false") + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 508bdd5050b54..478ff3a7c1011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -601,7 +601,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -680,7 +681,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 84b1e3fbda876..c9c60ef1be640 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,16 +56,17 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { def forNullValue(dataType: DataType): ExprCode = { val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode(code = "", isNull = TrueLiteral, + value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) } - def forNonNullValue(value: String): ExprCode = { - ExprCode(code = "", isNull = "false", value = value) + def forNonNullValue(value: ExprValue): ExprCode = { + ExprCode(code = "", isNull = FalseLiteral, value = value) } } @@ -77,7 +78,7 @@ object ExprCode { * @param value A term for a value of a common sub-expression. Not valid if `isNull` * is set to `true`. */ -case class SubExprEliminationState(isNull: String, value: String) +case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) /** * Codes and common subexpressions mapping used for subexpression elimination. @@ -330,7 +331,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) } def declareMutableStates(): String = { @@ -1003,7 +1004,8 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), + GlobalValue(value, javaType(expr.dataType))) subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index e12420bb5dfdd..a91989e129664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -59,7 +59,7 @@ trait CodegenFallback extends Expression { $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; - """, isNull = "false") + """, isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala new file mode 100644 index 0000000000000..df5f1c58b1b2d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.language.implicitConversions + +import org.apache.spark.sql.types.DataType + +// An abstraction that represents the evaluation result of [[ExprCode]]. +abstract class ExprValue { + + val javaType: String + + // Whether we can directly access the evaluation value anywhere. + // For example, a variable created outside a method can not be accessed inside the method. + // For such cases, we may need to pass the evaluation as parameter. + val canDirectAccess: Boolean + + def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + +// A literal evaluation of [[ExprCode]]. +class LiteralValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +object LiteralValue { + def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) + def unapply(literal: LiteralValue): Option[(String, String)] = + Some((literal.value, literal.javaType)) +} + +// A variable evaluation of [[ExprCode]]. +case class VariableValue( + val variableName: String, + val javaType: String) extends ExprValue { + override def toString: String = variableName + override val canDirectAccess: Boolean = false +} + +// A statement evaluation of [[ExprCode]]. +case class StatementValue( + val statement: String, + val javaType: String, + val canDirectAccess: Boolean = false) extends ExprValue { + override def toString: String = statement +} + +// A global variable evaluation of [[ExprCode]]. +case class GlobalValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +case object TrueLiteral extends LiteralValue("true", "boolean") +case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d35fd8ecb4d63..3ae0b54c754cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { + val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, isNull, value, i) + """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, value) + val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f92f70ee71fef..a30a0b22cd305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +76,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) } private def createCodeForArray( @@ -89,8 +91,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe( - ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) + val elementConverter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), + CodeGenerator.javaType(elementType)), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +107,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) } private def createCodeForMap( @@ -125,19 +128,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) } @tailrec private def convertToSafe( ctx: CodegenContext, - input: String, + input: ExprValue, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", "false", input) + case _ => ExprCode("", FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index ab2254cd9f70a..4a4d76313a543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt))) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -334,7 +336,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $evalSubexpr $writeExpressions """ - ExprCode(code, "false", s"$rowWriter.getRow()") + // `rowWriter` is declared as a class field, so we can access it directly in methods. + ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", + canDirectAccess = true)) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb84694c44e8..91188da8b0bd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = "false") + (${childGen.value}).numElements();""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 85facdad43db7..49a8d12057188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = arrayData, - isNull = "false") + value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + isNull = FalseLiteral) } override def prettyName: String = "array" @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = "false", value = eval.value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f4e9619bac59d..409c0b6b79b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,7 +191,8 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) + ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + CodeGenerator.javaType(dataType)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1ae4e5a2f716b..49dd988b4b53c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,7 +813,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", "true", "(UTF8String) null") + ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", + CodeGenerator.javaType(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 4f4d49166e88c..3af4bfebad45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b76b64ab5096f..df29c38d64d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -644,7 +644,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 07785e7448586..2a3cc580273ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = "false") + s"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = "false") + s"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = "false") + s"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 7395609a04ba5..742a650eb445d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -283,36 +283,36 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { } else { dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(value.toString) + ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue("Float.NaN") + ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) case Float.PositiveInfinity => - ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) case Float.NegativeInfinity => - ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}F") + ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue("Double.NaN") + ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}D") + ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(s"($javaType)$value") + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) case TimestampType | LongType => - ExprCode.forNonNullValue(s"${value}L") + ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(constRef) + ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a390f8ef7fd9a..7081a5e096d56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -91,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = TrueLiteral, + value = LiteralValue("null", CodeGenerator.javaType(dataType))) } override def sql: String = s"assert_true(${child.sql})" @@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Uuid = Uuid(randomSeed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b35fa72e95d1e..55b6e346be82a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -235,7 +236,7 @@ case class IsNaN(child: Expression) extends UnaryExpression ev.copy(code = s""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } } @@ -320,7 +321,12 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = eval.isNull) + val value = if (eval.isNull.isInstanceOf[LiteralValue]) { + LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NULL)" @@ -346,7 +352,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") + val value = if (eval.isNull == TrueLiteral) { + FalseLiteral + } else if (eval.isNull == FalseLiteral) { + TrueLiteral + } else { + StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -441,6 +454,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate | $codes |} while (false); |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9252425f86473..b2cca3178cd2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -61,13 +61,13 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, String) = { + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - resultIsNull + GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) } else { - "false" + FalseLiteral } val argValues = arguments.map { e => val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") @@ -244,7 +244,7 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = "false" + ev.isNull = FalseLiteral "" } @@ -546,7 +546,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -568,7 +568,13 @@ case class LambdaVariable( } override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + val isNullValue = if (nullable) { + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } + ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), + isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make @@ -840,7 +846,7 @@ case class MapObjects private( // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { + val genFunctionValue: String = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) @@ -1343,7 +1349,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1538,7 +1544,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = "false", value = childGen.value) + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } } @@ -1589,7 +1595,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4b85d9adbe311..e195ec17f3bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -405,7 +405,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -461,7 +461,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" + ev.isNull = FalseLiteral ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -469,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -615,7 +615,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false") + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f36633867316e..70186053617f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -83,7 +83,7 @@ case class Rand(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Rand = Rand(child) @@ -120,7 +120,7 @@ case class Randn(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Randn = Randn(child) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 398b6767654fa..8e83b35c3809c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,6 +448,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) + val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), + VariableValue("dummy", "boolean")) // raw testing of basic functionality { @@ -457,7 +459,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) assert(ctx.subExprEliminationExprs.contains(ref)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { assert(ctx.subExprEliminationExprs.contains(add1)) assert(!ctx.subExprEliminationExprs.contains(ref)) Seq.empty @@ -475,7 +477,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE assert(ctx.subExprEliminationExprs.contains(add1)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { assert(ctx.subExprEliminationExprs.contains(ref)) assert(!ctx.subExprEliminationExprs.contains(add1)) Seq.empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala new file mode 100644 index 0000000000000..c8f4cff7db48d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + +class ExprValueSuite extends SparkFunSuite { + + test("TrueLiteral and FalseLiteral should be LiteralValue") { + val trueLit = TrueLiteral + val falseLit = FalseLiteral + + assert(trueLit.value == "true") + assert(falseLit.value == "false") + + assert(trueLit.isPrimitive) + assert(falseLit.isPrimitive) + + trueLit match { + case LiteralValue(value, javaType) => + assert(value == "true" && javaType == "boolean") + case _ => fail() + } + + falseLit match { + case LiteralValue(value, javaType) => + assert(value == "false" && javaType == "boolean") + case _ => fail() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 392906a022903..434214a10e1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -51,7 +51,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { nullable: Boolean): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val isNullVar = if (nullable) { + VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { @@ -62,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, valueVar) + ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 12ae1ea4a7c13..0d9a62cace62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,7 +157,8 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 384f0398a1ec0..85c5ebfdaa689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -170,9 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", s"$index == -1", index)) + Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), + VariableValue(index, CodeGenerator.JAVA_INT))) } else { - Seq(ExprCode("", "false", index)) + Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) } } else { Seq.empty @@ -315,9 +316,11 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, javaType)) } else { - ExprCode(s"$javaType $value = $getter;", "false", value) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, + VariableValue(value, javaType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6ddaacfee1a40..805ff3cf001ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", "false", row) + ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -126,10 +126,10 @@ trait CodegenSupport extends SparkPlan { |$evaluateInputs |${ev.code.trim} """.stripMargin.trim - ExprCode(code, "false", ev.value) + ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", "false", "unsafeRow") + ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) } } } @@ -241,15 +241,16 @@ trait CodegenSupport extends SparkPlan { parameters += s"$paramType $paramName" val paramIsNull = if (!attributes(i).nullable) { // Use constant `false` without passing `isNull` for non-nullable variable. - "false" + FalseLiteral } else { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - isNull + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) } - paramVars += ExprCode("", paramIsNull, paramName) + paramVars += ExprCode("", paramIsNull, + VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 1926e9373bc55..8f7f10243d4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 6b60b414ffe5f..4978954271311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 4707022f74547..cab7081400ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = "false" + ev.isNull = FalseLiteral } ev } @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", "false", value) + val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 487d6a2383318..fa62a32d51f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))) } } } @@ -487,7 +488,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + val resultVar = input ++ Seq(ExprCode("", FalseLiteral, + VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 5a511b30e4fd9..b61acb8d4fda9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -531,11 +531,13 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, isNull, value), leftVarsDecl) + (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, "false", value), leftVarsDecl) + (ExprCode(code, FalseLiteral, + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } }.unzip } From 252468a744b95082400ba9e8b2e3b3d9d50ab7fa Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 9 Apr 2018 12:18:07 -0700 Subject: [PATCH 279/593] [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ## What changes were proposed in this pull request? API: ``` trait ClassificationNode extends Node def getLabelCount(label: Int): Double trait RegressionNode extends Node def getCount(): Double def getSum(): Double def getSquareSum(): Double // turn LeafNode to be trait trait LeafNode extends Node { def prediction: Double def impurity: Double ... } class ClassificationLeafNode extends ClassificationNode with LeafNode class RegressionLeafNode extends RegressionNode with LeafNode // turn InternalNode to be trait trait InternalNode extends Node{ def gain: Double def leftChild: Node def rightChild: Node def split: Split ... } class ClassificationInternalNode extends ClassificationNode with InternalNode override def leftChild: ClassificationNode override def rightChild: ClassificationNode class RegressionInternalNode extends RegressionNode with InternalNode override val leftChild: RegressionNode override val rightChild: RegressionNode class DecisionTreeClassificationModel override val rootNode: ClassificationNode class DecisionTreeRegressionModel override val rootNode: RegressionNode ``` Closes #17466 ## How was this patch tested? UT will be added soon. Author: WeichenXu Author: jkbradley Closes #20786 from WeichenXu123/tree_stat_api_2. --- .../DecisionTreeClassifier.scala | 14 +- .../ml/classification/GBTClassifier.scala | 6 +- .../RandomForestClassifier.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 6 +- .../scala/org/apache/spark/ml/tree/Node.scala | 247 ++++++++++++++---- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../org/apache/spark/ml/tree/treeModels.scala | 36 ++- .../DecisionTreeClassifierSuite.scala | 31 ++- .../classification/GBTClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 14 + .../ml/tree/impl/RandomForestSuite.scala | 22 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 12 +- project/MimaExcludes.scala | 9 +- 16 files changed, 333 insertions(+), 108 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 65cce697d8202..771cd4fe91dcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: Node, + @Since("1.4.0")override val rootNode: ClassificationNode, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override def predict(features: Vector): Double = { @@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) + val model = new DecisionTreeClassificationModel(metadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) + new DecisionTreeClassificationModel(uid, + rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cd44489f618b2..c0255103bc313 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -371,14 +371,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 78a4972adbdbb..bb972e9706fc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -310,15 +310,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + val tree = new DecisionTreeClassificationModel(treeMetadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ad154fcd010cc..5cef5c9f21f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node, + override val rootNode: RegressionNode, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int) = + private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override def predict(features: Vector): Double = { @@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) + val model = new DecisionTreeRegressionModel(metadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 6569ff2a5bfc1..834aaa0e362d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -302,15 +302,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2d594460c2475..7f77398ba2a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -269,13 +269,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d30be452a436e..0242bc76698d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed abstract class Node extends Serializable { +sealed trait Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -84,35 +86,86 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + def fromOld( + oldNode: OldNode, + categoricalFeatures: Map[Int, Int], + isClassification: Boolean): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) + if (isClassification) { + new ClassificationLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } else { + new RegressionLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, - gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + if (isClassification) { + new ClassificationInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } else { + new RegressionInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } } } } -/** - * Decision tree leaf node. - * @param prediction Prediction this node makes - * @param impurity Impurity measure at this node (for training data) - */ -class LeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { +@Since("2.4.0") +sealed trait ClassificationNode extends Node { + + /** + * Get count of training examples for specified label in this node + * @param label label number in the range [0, numClasses) + */ + @Since("2.4.0") + def getLabelCount(label: Int): Double = { + require(label >= 0 && label < impurityStats.stats.length, + "label should be in the range between 0 (inclusive) " + + s"and ${impurityStats.stats.length} (exclusive).") + impurityStats.stats(label) + } +} + +@Since("2.4.0") +sealed trait RegressionNode extends Node { + + /** Number of training data points in this node */ + @Since("2.4.0") + def getCount: Double = impurityStats.stats(0) + + /** Sum over training data points of the labels in this node */ + @Since("2.4.0") + def getSum: Double = impurityStats.stats(1) + + /** Sum over training data points of the square of the labels in this node */ + @Since("2.4.0") + def getSumOfSquares: Double = impurityStats.stats(2) +} + +@Since("2.4.0") +sealed trait LeafNode extends Node { + + /** Prediction this node makes. */ + def prediction: Double + + def impurity: Double override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -135,32 +188,58 @@ class LeafNode private[ml] ( override private[ml] def maxSplitFeatureIndex(): Int = -1 +} + +/** + * Decision tree leaf node for classification. + */ +@Since("2.4.0") +class ClassificationLeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with LeafNode { + override private[tree] def deepCopy(): Node = { - new LeafNode(prediction, impurity, impurityStats) + new ClassificationLeafNode(prediction, impurity, impurityStats) } } /** - * Internal Decision Tree node. - * @param prediction Prediction this node would make if it were a leaf node - * @param impurity Impurity measure at this node (for training data) - * @param gain Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - * @param leftChild Left-hand child node - * @param rightChild Right-hand child node - * @param split Information about the test used to split to the left or right child. + * Decision tree leaf node for regression. */ -class InternalNode private[ml] ( +@Since("2.4.0") +class RegressionLeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - val gain: Double, - val leftChild: Node, - val rightChild: Node, - val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with LeafNode { - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. + override private[tree] def deepCopy(): Node = { + new RegressionLeafNode(prediction, impurity, impurityStats) + } +} + +/** + * Internal Decision Tree node. + */ +@Since("2.4.0") +sealed trait InternalNode extends Node { + + /** + * Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + */ + def gain: Double + + /** Left-hand child node */ + def leftChild: Node + + /** Right-hand child node */ + def rightChild: Node + + /** Information about the test used to split to the left or right child. */ + def split: Split override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -205,11 +284,6 @@ class InternalNode private[ml] ( math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } - - override private[tree] def deepCopy(): Node = { - new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), - split, impurityStats) - } } private object InternalNode { @@ -240,6 +314,57 @@ private object InternalNode { } } +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class ClassificationInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: ClassificationNode, + override val rightChild: ClassificationNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new ClassificationInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[ClassificationNode], + rightChild.deepCopy().asInstanceOf[ClassificationNode], + split, impurityStats) + } +} + +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class RegressionInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: RegressionNode, + override val rightChild: RegressionNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new RegressionInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[RegressionNode], + rightChild.deepCopy().asInstanceOf[RegressionNode], + split, impurityStats) + } +} + + /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -265,30 +390,52 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { - def toNode: Node = toNode(prune = true) + def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) + + def toClassificationNode(prune: Boolean = true): ClassificationNode = { + toNode(true, prune).asInstanceOf[ClassificationNode] + } + + def toRegressionNode(prune: Boolean = true): RegressionNode = { + toNode(false, prune).asInstanceOf[RegressionNode] + } /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode(prune: Boolean = true): Node = { + def toNode(isClassification: Boolean, prune: Boolean): Node = { if (!leftChild.isEmpty || !rightChild.isEmpty) { assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + (leftChild.get.toNode(isClassification, prune), + rightChild.get.toNode(isClassification, prune)) match { case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => - new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + if (isClassification) { + new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } else { + new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } case (l, r) => - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - l, r, split.get, stats.impurityCalculator) + if (isClassification) { + new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, + stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], + split.get, stats.impurityCalculator) + } else { + new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], + split.get, stats.impurityCalculator) + } } } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + val impurity = if (stats.valid) stats.impurity else -1.0 + if (isClassification) { + new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, stats.impurityCalculator) } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + new RegressionLeafNode(stats.impurityCalculator.predict, impurity, + stats.impurityCalculator) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 16f32d76b9984..056a94b351f79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -224,23 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), + numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, + new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 4aa4c3617e7fd..f027b14f1d476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case n: LeafNode => + case _: LeafNode => // do nothing + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession): Node = { + sparkSession: SparkSession, isClassification: Boolean): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), impurityType, isClassification) } /** @@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String, + isClassification: Boolean): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, - n.split.getSplit, impurityStats) + if (isClassification) { + new ClassificationInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], + n.split.getSplit, impurityStats) + } else { + new RegressionInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], + n.split.getSplit, impurityStats) + } } else { - new LeafNode(n.prediction, n.impurity, impurityStats) + if (isClassification) { + new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) + } else { + new RegressionLeafNode(n.prediction, n.impurity, impurityStats) + } } finalNodes(n.id) = node } @@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String, + isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( + nodeData.toArray, impurityType, isClassification) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2930f4900d50e..d3dbb4e754d3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -61,7 +61,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -375,6 +376,32 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(model) } + + test("label/impurity stats") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) + val dt1 = new DecisionTreeClassifier() + .setImpurity("entropy") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model1 = dt1.fit(df) + + val rootNode1 = model1.rootNode + assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) + + val dt2 = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model2 = dt2.fit(df) + + val rootNode2 = model2.rootNode + assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 57796069f6052..f0ee5496f9d1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.RegressionLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -69,7 +69,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ba4a9cf082785..3062aa9f3d274 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -71,7 +71,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 29a438396516b..9ae27339b11d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -191,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("label/impurity stats") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + val statInfo = model.rootNode + + assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 + && statInfo.getSumOfSquares == 600.0) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..4dbbd75d2466d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. @@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) - .asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], + numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..3f03d909d4a4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split): Node = { + def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + if (isClassification) { + new ClassificationInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], + split, parentImp) + } else { + new RegressionInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], + split, parentImp) + } } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1b6d1dec69d49..b37b4d51775e8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,7 +55,14 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") ) // Exclude rules for 2.3.x From 61b724724cc4a18818774ecaaa5a45b70fdb8dae Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Apr 2018 14:07:33 -0700 Subject: [PATCH 280/593] [INFRA] Close stale PRs. Closes #20957 Closes #20792 From f94f3624ea81053653a06560808cb71f510c6828 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 9 Apr 2018 21:07:28 -0700 Subject: [PATCH 281/593] [SPARK-23947][SQL] Add hashUTF8String convenience method to hasher classes ## What changes were proposed in this pull request? Add `hashUTF8String()` to the hasher classes to allow Spark SQL codegen to generate cleaner code for hashing `UTF8String`s. No change in behavior otherwise. Although with the introduction of SPARK-10399, the code size for hashing `UTF8String` is already smaller, it's still good to extract a separate function in the hasher classes so that the generated code can stay clean. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21016 from rednaxelafx/hashutf8. --- .../apache/spark/sql/catalyst/expressions/HiveHasher.java | 5 +++++ .../java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 7 ++++++- .../org/apache/spark/sql/catalyst/expressions/XXH64.java | 5 +++++ .../org/apache/spark/sql/catalyst/expressions/hash.scala | 6 ++---- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c34e36903a93e..62b75ae8aa01d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -51,4 +52,8 @@ public static int hashUnsafeBytesBlock(MemoryBlock mb) { public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); } + + public static int hashUTF8String(UTF8String str) { + return hashUnsafeBytesBlock(str.getMemoryBlock()); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index f372b19fac119..aff6e93d647fe 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -20,6 +20,7 @@ import com.google.common.primitives.Ints; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -82,6 +83,10 @@ public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { return fmix(h1, lengthInBytes); } + public static int hashUTF8String(UTF8String str, int seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } @@ -91,7 +96,7 @@ public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, } public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { - // This is compatible with original and another implementations. + // This is compatible with original and other implementations. // Use this method for new components after Spark 2.3. int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index fe727f6011cbf..8e9c0a2e9dc81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off /** @@ -107,6 +108,10 @@ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { return fmix(hash); } + public static long hashUTF8String(UTF8String str, long seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index df29c38d64d3d..ef790338bdd27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -361,8 +361,7 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" + s"$result = $hasherClassName.hashUTF8String($input, $result);" } protected def genHashForMap( @@ -725,8 +724,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" + s"$result = $hasherClassName.hashUTF8String($input);" } override protected def genHashForArray( From 64988841540464e261b0cbaede43058e7bd36261 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 9 Apr 2018 21:49:49 -0700 Subject: [PATCH 282/593] [SPARK-23898][SQL] Simplify add & subtract code generation ## What changes were proposed in this pull request? Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication. This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21005 from hvanhovell/SPARK-23898. --- .../sql/catalyst/expressions/arithmetic.scala | 50 ++++++++----------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 478ff3a7c1011..defd6f3cd8849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") // codegen would fail to compile if we just write (-($c)) @@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) - case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { @@ -104,7 +104,7 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") @@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType - override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") + /** Name of the function for this expression on a [[CalendarInterval]] type. */ + def calendarIntervalMethod: String = + sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case CalendarIntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, @@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" + + override def calendarIntervalMethod: String = "add" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { numeric.plus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" + override def decimalMethod: String = "$minus" + + override def calendarIntervalMethod: String = "subtract" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti numeric.minus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "pmod" - protected def checkTypesInternal(t: DataType) = + protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeUtils.checkForNumericExpr(t, "pmod") override def inputType: AbstractDataType = NumericType From 95034af69623bb8be5b9f5eabf50980bdeca48e6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 10 Apr 2018 08:51:35 -0500 Subject: [PATCH 283/593] [SPARK-23841][ML] NodeIdCache should unpersist the last cached nodeIdsForInstances ## What changes were proposed in this pull request? unpersist the last cached nodeIdsForInstances in `deleteAllCheckpoints` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20956 from zhengruifeng/NodeIdCache_cleanup. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index a7c5f489dea86..5b14a63ada4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -95,7 +95,7 @@ private[spark] class NodeIdCache( splits: Array[Array[Split]]): Unit = { if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } prevNodeIdsForInstances = nodeIdsForInstances @@ -166,9 +166,13 @@ private[spark] class NodeIdCache( } } } + if (nodeIdsForInstances != null) { + // Unpersist current one if one exists. + nodeIdsForInstances.unpersist(false) + } if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } } } From 3323b156f9c0beb0b3c2b724a6faddc6ffdfe99a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 10 Apr 2018 17:32:00 +0200 Subject: [PATCH 284/593] [SPARK-23864][SQL] Add unsafe object writing to UnsafeWriter ## What changes were proposed in this pull request? This PR moves writing of `UnsafeRow`, `UnsafeArrayData` & `UnsafeMapData` out of the `GenerateUnsafeProjection`/`InterpretedUnsafeProjection` classes into the `UnsafeWriter` interface. This cleans up the code a little bit, and it should also result in less byte code for the code generated path. ## How was this patch tested? Existing tests Author: Herman van Hovell Closes #20986 from hvanhovell/SPARK-23864. --- .../expressions/codegen/UnsafeWriter.java | 72 ++-- .../InterpretedUnsafeProjection.scala | 46 +-- .../codegen/GenerateUnsafeProjection.scala | 322 ++++++++---------- .../spark/sql/types/UserDefinedType.scala | 10 + 4 files changed, 204 insertions(+), 246 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index de0eb6dbb76be..2781655002000 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -103,21 +106,7 @@ protected final void zeroOutPaddingBytes(int numBytes) { public abstract void write(int ordinal, Decimal input, int precision, int scale); public final void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(getBuffer(), cursor()); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - increaseCursor(roundedSize); + writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes()); } public final void write(int ordinal, byte[] input) { @@ -125,20 +114,19 @@ public final void write(int ordinal, byte[] input) { } public final void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes); + } - // grow the global buffer before writing data. + private void writeUnalignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); grow(roundedSize); - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); - + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. increaseCursor(roundedSize); } @@ -156,6 +144,40 @@ public final void write(int ordinal, CalendarInterval input) { increaseCursor(16); } + public final void write(int ordinal, UnsafeRow row) { + writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); + } + + public final void write(int ordinal, UnsafeMapData map) { + writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes()); + } + + public final void write(UnsafeArrayData array) { + // Unsafe arrays both can be written as a regular array field or as part of a map. This makes + // updating the offset and size dependent on the code path, this is why we currently do not + // provide an method for writing unsafe arrays that also updates the size and offset. + int numBytes = array.getSizeInBytes(); + grow(numBytes); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + getBuffer(), + cursor(), + numBytes); + increaseCursor(numBytes); + } + + private void writeAlignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + grow(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(numBytes); + } + protected final void writeBoolean(long offset, boolean value) { Platform.putBoolean(getBuffer(), offset, value); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index b31466f5c92d1..6d69d69b1c802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => - writeUnsafeData( - rowWriter, - row.getBaseObject, - row.getBaseOffset, - row.getSizeInBytes) + writer.write(i, row) case row => + val previousCursor = writer.cursor() // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. rowWriter.resetRowWriter() structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => @@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => - writeUnsafeData( - valueArrayWriter, - map.getBaseObject, - map.getBaseOffset, - map.getSizeInBytes) + writer.write(i, map) case map => + val previousCursor = writer.cursor() + // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) valueArrayWriter.increaseCursor(8) @@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => @@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementWriter: (SpecializedGetters, Int) => Unit, array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => - writeUnsafeData( - arrayWriter, - unsafe.getBaseObject, - unsafe.getBaseOffset, - unsafe.getSizeInBytes) + arrayWriter.write(unsafe) case _ => val numElements = array.numElements() arrayWriter.initialize(numElements) @@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { i += 1 } } - - /** - * Write an opaque block of data to the buffer. This is used to copy - * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. - */ - private def writeUnsafeData( - writer: UnsafeWriter, - baseObject: AnyRef, - baseOffset: Long, - sizeInBytes: Int) : Unit = { - writer.grow(sizeInBytes) - Platform.copyMemory( - baseObject, - baseOffset, - writer.getBuffer, - writer.cursor, - sizeInBytes) - writer.increaseCursor(sizeInBytes) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4a4d76313a543..2fb441ac4500e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,14 +32,13 @@ import org.apache.spark.sql.types._ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { /** Returns true iff we support this data type. */ - def canSupport(dataType: DataType): Boolean = dataType match { + def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true - case t: AtomicType => true + case _: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -47,6 +46,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeStructToBuffer( ctx: CodegenContext, input: String, + index: String, fieldTypes: Seq[DataType], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. @@ -60,15 +60,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") - + val previousCursor = ctx.freshName("previousCursor") s""" - final InternalRow $tmpInput = $input; - if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} - } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} - } - """ + |final InternalRow $tmpInput = $input; + |if ($tmpInput instanceof UnsafeRow) { + | $rowWriter.write($index, (UnsafeRow) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin } private def writeExpressionsToBuffer( @@ -95,10 +99,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => - val dt = dataType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -106,58 +107,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } - val previousCursor = ctx.freshName("previousCursor") - - val writeField = dt match { - case t: StructType => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$rowWriter.write($index, ${input.value});" - } + val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) if (input.isNull == "false") { s""" - ${input.code} - ${writeField.trim} - """ + |${input.code} + |${writeField.trim} + """.stripMargin } else { s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { - ${writeField.trim} - } - """ + |${input.code} + |if (${input.isNull}) { + | ${setNull.trim} + |} else { + | ${writeField.trim} + |} + """.stripMargin } } @@ -171,11 +136,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro funcName = "writeFields", arguments = Seq("InternalRow" -> row)) } - s""" - $resetWriter - $writeFieldsCode - """.trim + |$resetWriter + |$writeFieldsCode + """.stripMargin } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -189,10 +153,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val et = elementType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val et = UserDefinedType.sqlType(elementType) val jt = CodeGenerator.javaType(et) @@ -205,106 +166,100 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val previousCursor = ctx.freshName("previousCursor") val element = CodeGenerator.getValue(tmpInput, et, index) - val writeElement = et match { - case t: StructType => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$arrayWriter.write($index, $element);" - } - val primitiveTypeName = - if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" - final ArrayData $tmpInput = $input; - if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} - } else { - final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($numElements); - - for (int $index = 0; $index < $numElements; $index++) { - if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - } else { - $writeElement - } - } - } - """ + |final ArrayData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeArrayData) { + | $rowWriter.write((UnsafeArrayData) $tmpInput); + |} else { + | final int $numElements = $tmpInput.numElements(); + | $arrayWriter.initialize($numElements); + | + | for (int $index = 0; $index < $numElements; $index++) { + | if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + | } else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + | } + | } + |} + """.stripMargin } // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, + index: String, keyType: DataType, valueType: DataType, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") + val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final MapData $tmpInput = $input; - if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} - } else { - // preserve 8 bytes to write the key array numBytes later. - $rowWriter.grow(8); - $rowWriter.increaseCursor(8); + |final MapData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeMapData) { + | $rowWriter.write($index, (UnsafeMapData) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | + | // preserve 8 bytes to write the key array numBytes later. + | $rowWriter.grow(8); + | $rowWriter.increaseCursor(8); + | + | // Remember the current cursor so that we can write numBytes of key array later. + | final int $tmpCursor = $rowWriter.cursor(); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | + | // Write the numBytes of key array into the first 8 bytes. + | Platform.putLong( + | $rowWriter.getBuffer(), + | $tmpCursor - 8, + | $rowWriter.cursor() - $tmpCursor); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin + } - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $rowWriter.cursor(); + private def writeElement( + ctx: CodegenContext, + input: String, + index: String, + dt: DataType, + writer: String): String = dt match { + case t: StructType => + writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + + case ArrayType(et, _) => + val previousCursor = ctx.freshName("previousCursor") + s""" + |// Remember the current cursor so that we can calculate how many bytes are + |// written later. + |final int $previousCursor = $writer.cursor(); + |${writeArrayToBuffer(ctx, input, et, writer)} + |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + """.stripMargin - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} - // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); + case MapType(kt, vt, _) => + writeMapToBuffer(ctx, input, index, kt, vt, writer) - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} - } - """ - } + case DecimalType.Fixed(precision, scale) => + s"$writer.write($index, $input, $precision, $scale);" - /** - * If the input is already in unsafe format, we don't need to go through all elements/fields, - * we can directly write it. - */ - private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { - val sizeInBytes = ctx.freshName("sizeInBytes") - s""" - final int $sizeInBytes = $input.getSizeInBytes(); - // grow the global buffer before writing data. - $rowWriter.grow($sizeInBytes); - $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); - $rowWriter.increaseCursor($sizeInBytes); - """ + case NullType => "" + + case _ => s"$writer.write($index, $input);" } def createCode( @@ -332,10 +287,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" - $rowWriter.reset(); - $evalSubexpr - $writeExpressions - """ + |$rowWriter.reset(); + |$evalSubexpr + |$writeExpressions + """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", canDirectAccess = true)) @@ -363,38 +318,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new SpecificUnsafeProjection(references); - } - - class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - - private Object[] references; - ${ctx.declareMutableStates()} - - public SpecificUnsafeProjection(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public void initialize(int partitionIndex) { - ${ctx.initPartition()} - } - - // Scala.Function1 need this - public java.lang.Object apply(java.lang.Object row) { - return apply((InternalRow) row); - } - - public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code.trim} - return ${eval.value}; - } - - ${ctx.declareAddedFunctions()} - } - """ + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificUnsafeProjection(references); + |} + | + |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { + | + | private Object[] references; + | ${ctx.declareMutableStates()} + | + | public SpecificUnsafeProjection(Object[] references) { + | this.references = references; + | ${ctx.initMutableStates()} + | } + | + | public void initialize(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | // Scala.Function1 need this + | public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + | } + | + | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | return ${eval.value}; + | } + | + | ${ctx.declareAddedFunctions()} + |} + """.stripMargin val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5a944e763e099..6af16e2dba105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def catalogString: String = sqlType.simpleString } +private[spark] object UserDefinedType { + /** + * Get the sqlType of a (potential) [[UserDefinedType]]. + */ + def sqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case _ => dt + } +} + /** * The user defined type in Python. * From e179658914963de472120a81621396706584c949 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 10 Apr 2018 09:33:09 -0700 Subject: [PATCH 285/593] [SPARK-19724][SQL][FOLLOW-UP] Check location of managed table when ignoreIfExists is true ## What changes were proposed in this pull request? In the PR #20886, I mistakenly check the table location only when `ignoreIfExists` is false, which was following the original deprecated PR. That was wrong. When `ignoreIfExists` is true and the target table doesn't exist, we should also check the table location. In other word, **`ignoreIfExists` has nothing to do with table location validation**. This is a follow-up PR to fix the mistake. ## How was this patch tested? Add one unit test. Author: Gengliang Wang Closes #21001 from gengliangwang/SPARK-19724-followup. --- .../spark/sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++++-- .../execution/command/createDataSourceTables.scala | 2 +- .../apache/spark/sql/execution/command/DDLSuite.scala | 9 +++++++++ .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 52ed89ef8d781..c390337c03ff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -286,7 +286,10 @@ class SessionCatalog( * Create a metastore table in the database specified in `tableDefinition`. * If no such database is specified, create it in the current database. */ - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + def createTable( + tableDefinition: CatalogTable, + ignoreIfExists: Boolean, + validateLocation: Boolean = true): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -305,7 +308,11 @@ class SessionCatalog( } requireDbExists(db) - if (!ignoreIfExists) { + if (tableExists(newTableDefinition.identifier)) { + if (!ignoreIfExists) { + throw new TableAlreadyExistsException(db = db, table = table) + } + } else if (validateLocation) { validateTableLocation(newTableDefinition) } externalCatalog.createTable(newTableDefinition, ignoreIfExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f7c3e9b019258..f6ef433f2ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -182,7 +182,7 @@ case class CreateDataSourceTableAsSelectCommand( // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) // Table location is already validated. No need to check it again during table creation. - sessionState.catalog.createTable(newTable, ignoreIfExists = true) + sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4304d0b6f6b16..cbd7f9d6f67be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -425,6 +425,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") }.getMessage assert(ex.contains(exMsgWithDefaultDB)) + + // Always check location of managed table, with or without (IF NOT EXISTS) + withTable("tab2") { + sql(s"CREATE TABLE tab2 (col1 int, col2 string) USING ${dataSource}") + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE IF NOT EXISTS tab1 LIKE tab2") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } } finally { waitForTasksToFinish() Utils.deleteRecursively(tableLoc) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index db76ec9d084cb..c85db78c732de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1461,7 +1461,7 @@ class HiveDDLSuite assert(e2.getMessage.contains(forbiddenPrefix + "foo")) val e3 = intercept[AnalysisException] { - sql(s"CREATE TABLE tbl (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") + sql(s"CREATE TABLE tbl2 (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") } assert(e3.getMessage.contains(forbiddenPrefix + "foo")) } From adb222b957f327a69929b8f16fa5ebc071fa99e3 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Apr 2018 11:18:14 -0700 Subject: [PATCH 286/593] [SPARK-23751][ML][PYSPARK] Kolmogorov-Smirnoff test Python API in pyspark.ml ## What changes were proposed in this pull request? Kolmogorov-Smirnoff test Python API in `pyspark.ml` **Note** API with `CDF` is a little difficult to support in python. We can add it in following PR. ## How was this patch tested? doctest Author: WeichenXu Closes #20904 from WeichenXu123/ks-test-py. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 29 +-- python/pyspark/ml/stat.py | 181 ++++++++++++------ 2 files changed, 138 insertions(+), 72 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index c62d7463288f7..af8ff64d33ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.java.function.Function import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.stat.{Statistics => OldStatistics} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col /** @@ -59,7 +59,7 @@ object KolmogorovSmirnovTest { * distribution of the sample data and the theoretical distribution we can provide a test for the * the null hypothesis that the sample data comes from that theoretical distribution. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value * @return DataFrame containing the test result for the input sampled data. @@ -68,10 +68,10 @@ object KolmogorovSmirnovTest { * - `statistic: Double` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + def test(dataset: Dataset[_], sampleCol: String, cdf: Double => Double): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) @@ -81,10 +81,11 @@ object KolmogorovSmirnovTest { * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, - cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - val f: Double => Double = x => cdf.call(x) - test(dataset, sampleCol, f) + def test( + dataset: Dataset[_], + sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) } /** @@ -92,10 +93,11 @@ object KolmogorovSmirnovTest { * distribution equality. Currently supports the normal distribution, taking as parameters * the mean and standard deviation. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param distName a `String` name for a theoretical distribution, currently only support "norm". - * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @param params `Double*` specifying the parameters to be used for the theoretical distribution. + * For "norm" distribution, the parameters includes mean and variance. * @return DataFrame containing the test result for the input sampled data. * This DataFrame will contain a single Row with the following fields: * - `pValue: Double` @@ -103,10 +105,13 @@ object KolmogorovSmirnovTest { */ @Since("2.4.0") @varargs - def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + def test( + dataset: Dataset[_], + sampleCol: String, distName: String, + params: Double*): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 0eeb5e528434a..93d0f4fd9148f 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -32,32 +32,6 @@ class ChiSquareTest(object): The null hypothesis is that the occurrence of the outcomes is statistically independent. - :param dataset: - DataFrame of categorical labels and categorical features. - Real-valued features will be treated as categorical for each distinct value. - :param featuresCol: - Name of features column in dataset, of type `Vector` (`VectorUDT`). - :param labelCol: - Name of label column in dataset, of any numerical type. - :return: - DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: - - `pValues: Vector` - - `degreesOfFreedom: Array[Int]` - - `statistics: Vector` - Each of these fields has one value per feature. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import ChiSquareTest - >>> dataset = [[0, Vectors.dense([0, 0, 1])], - ... [0, Vectors.dense([1, 0, 1])], - ... [1, Vectors.dense([2, 1, 1])], - ... [1, Vectors.dense([3, 1, 1])]] - >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) - >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') - >>> chiSqResult.select("degreesOfFreedom").collect()[0] - Row(degreesOfFreedom=[3, 1, 0]) - .. versionadded:: 2.2.0 """ @@ -66,6 +40,32 @@ class ChiSquareTest(object): def test(dataset, featuresCol, labelCol): """ Perform a Pearson's independence test using dataset. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest @@ -85,40 +85,6 @@ class Correlation(object): which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` to avoid recomputing the common lineage. - :param dataset: - A dataset or a dataframe. - :param column: - The name of the column of vectors for which the correlation coefficient needs - to be computed. This must be a column of the dataset, and it must contain - Vector objects. - :param method: - String specifying the method to use for computing correlation. - Supported: `pearson` (default), `spearman`. - :return: - A dataframe that contains the correlation matrix of the column of vectors. This - dataframe contains a single row and a single column of name - '$METHODNAME($COLUMN)'. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import Correlation - >>> dataset = [[Vectors.dense([1, 0, 0, -2])], - ... [Vectors.dense([4, 5, 0, 3])], - ... [Vectors.dense([6, 7, 0, 8])], - ... [Vectors.dense([9, 0, 0, 1])]] - >>> dataset = spark.createDataFrame(dataset, ['features']) - >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] - >>> print(str(pearsonCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], - [ 0.0556..., 1. , NaN, 0.9135...], - [ NaN, NaN, 1. , NaN], - [ 0.4004..., 0.9135..., NaN, 1. ]]) - >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] - >>> print(str(spearmanCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], - [ 0.1054..., 1. , NaN, 0.9486... ], - [ NaN, NaN, 1. , NaN], - [ 0.4 , 0.9486... , NaN, 1. ]]) - .. versionadded:: 2.2.0 """ @@ -127,6 +93,40 @@ class Correlation(object): def corr(dataset, column, method="pearson"): """ Compute the correlation matrix with specified method using dataset. + + :param dataset: + A Dataset or a DataFrame. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A DataFrame that contains the correlation matrix of the column of vectors. This + DataFrame contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) """ sc = SparkContext._active_spark_context javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation @@ -134,6 +134,67 @@ def corr(dataset, column, method="pearson"): return _java2py(sc, javaCorrObj.corr(*args)) +class KolmogorovSmirnovTest(object): + """ + .. note:: Experimental + + Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous + distribution. + + By comparing the largest difference between the empirical cumulative + distribution of the sample data and the theoretical distribution we can provide a test for the + the null hypothesis that the sample data comes from that theoretical distribution. + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def test(dataset, sampleCol, distName, *params): + """ + Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution + equality. Currently supports the normal distribution, taking as parameters the mean and + standard deviation. + + :param dataset: + a Dataset or a DataFrame containing the sample of data to test. + :param sampleCol: + Name of sample column in dataset, of any numerical type. + :param distName: + a `string` name for a theoretical distribution, currently only support "norm". + :param params: + a list of `Double` values specifying the parameters to be used for the theoretical + distribution. For "norm" distribution, the parameters includes mean and variance. + :return: + A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data. + This DataFrame will contain a single Row with the following fields: + - `pValue: Double` + - `statistic: Double` + + >>> from pyspark.ml.stat import KolmogorovSmirnovTest + >>> dataset = [[-1.0], [0.0], [1.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + >>> dataset = [[2.0], [3.0], [4.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest + dataset = _py2java(sc, dataset) + params = [float(param) for param in params] + return _java2py(sc, javaTestObj.test(dataset, sampleCol, distName, + _jvm().PythonUtils.toSeq(params))) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 4f1e8b9bb7d795d4ca3d5cd5dcc0f9419e52dfae Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 10 Apr 2018 15:41:45 -0700 Subject: [PATCH 287/593] [SPARK-23871][ML][PYTHON] add python api for VectorAssembler handleInvalid ## What changes were proposed in this pull request? add python api for VectorAssembler handleInvalid ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #21003 from huaxingao/spark-23871. --- .../spark/ml/feature/VectorAssembler.scala | 12 +++--- python/pyspark/ml/feature.py | 42 ++++++++++++++++--- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 6bf4aa38b1fcb..4061154b39c14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) */ @Since("2.4.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with - |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the - |output). Column lengths are taken from the size of ML Attribute Group, which can be set using - |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred - |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. - |""".stripMargin.replaceAll("\n", " "), + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5a3e0dd655150..cdda30cfab482 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ A feature transformer that merges multiple columns into a vector column. @@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath) >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs True + >>> dfWithNullsAndNaNs = spark.createDataFrame( + ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"]) + >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features", + ... handleInvalid="keep") + >>> vecAssembler2.transform(dfWithNullsAndNaNs).show() + +---+---+----+-------------+ + | a| b| c| features| + +---+---+----+-------------+ + |1.0|2.0|null|[1.0,2.0,NaN]| + |3.0|NaN| 4.0|[3.0,NaN,4.0]| + |5.0|6.0| 7.0|[5.0,6.0,7.0]| + +---+---+----+-------------+ + ... + >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show() + +---+---+---+-------------+ + | a| b| c| features| + +---+---+---+-------------+ + |5.0|6.0|7.0|[5.0,6.0,7.0]| + +---+---+---+-------------+ + ... .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " + + "and NaN values). Options are 'skip' (filter out rows with invalid " + + "data), 'error' (throw an error), or 'keep' (return relevant number " + + "of NaN in the output). Column lengths are taken from the size of ML " + + "Attribute Group, which can be set using `VectorSizeHint` in a " + + "pipeline before `VectorAssembler`. Column lengths can also be " + + "inferred from first rows of the data since it is safe to do so but " + + "only in case of 'error' or 'skip').", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCols=None, outputCol=None): + def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCols=None, outputCol=None) + __init__(self, inputCols=None, outputCol=None, handleInvalid="error") """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) + self._setDefault(handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCols=None, outputCol=None): + def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCols=None, outputCol=None) + setParams(self, inputCols=None, outputCol=None, handleInvalid="error") Sets params for this VectorAssembler. """ kwargs = self._input_kwargs From 7c7570d466a8ded51e580eb6a28583bd9a9c5337 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 10 Apr 2018 17:26:06 -0700 Subject: [PATCH 288/593] [SPARK-23944][ML] Add the set method for the two LSHModel ## What changes were proposed in this pull request? Add two set method for LSHModel in LSH.scala, BucketedRandomProjectionLSH.scala, and MinHashLSH.scala ## How was this patch tested? New test for the param setup was added into - BucketedRandomProjectionLSHSuite.scala - MinHashLSHSuite.scala Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21015 from ludatabricks/SPARK-23944. --- .../spark/ml/feature/BucketedRandomProjectionLSH.scala | 8 ++++++++ .../src/main/scala/org/apache/spark/ml/feature/LSH.scala | 6 ++++++ .../scala/org/apache/spark/ml/feature/MinHashLSH.scala | 8 ++++++++ .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 8 ++++++++ .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 8 ++++++++ 5 files changed, 38 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 36a46ca6ff4b7..41eaaf9679914 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml]( private[ml] val randUnitVectors: Array[Vector]) extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { key: Vector => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 1c9f47a0b201d..a70931f783f45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams with MLWritable { self: T => + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 145422a059196..556848e45532d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -51,6 +51,14 @@ class MinHashLSHModel private[ml]( private[ml] val randCoefficients: Array[(Int, Int)]) extends LSHModel[MinHashLSHModel] { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { elems: Vector => { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ed9a39d8d1512..9b823259b1deb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest ParamsSuite.checkParams(model) } + test("setters") { + val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("BucketedRandomProjectionLSH: default params") { val brp = new BucketedRandomProjectionLSH assert(brp.getNumHashTables === 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 96df68dbdf053..3da0fb7da01ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa ParamsSuite.checkParams(model) } + test("setters") { + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("MinHashLSH: default params") { val rp = new MinHashLSH assert(rp.getNumHashTables === 1.0) From c7622befdadfea725797d76e820e3dfc76fec927 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:42:09 +0800 Subject: [PATCH 289/593] [SPARK-23847][FOLLOWUP][PYTHON][SQL] Actually test [desc|acs]_nulls_[first|last] functions in PySpark ## What changes were proposed in this pull request? There was a mistake in `tests.py` missing `assertEquals`. ## How was this patch tested? Fixed tests. Author: hyukjinkwon Closes #21035 from HyukjinKwon/SPARK-23847. --- python/pyspark/sql/tests.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dd04ffb4ed393..96c2a776a5049 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,19 +2991,23 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() - def test_2_4_functions(self): + def test_sort_with_nulls_order(self): from pyspark.sql import functions df = self.spark.createDataFrame( [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) - df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] - df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] - df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] - df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) class HiveSparkSubmitTests(SparkSubmitTests): From 87611bba222a95158fc5b638a566bdf47346da8e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:44:01 +0800 Subject: [PATCH 290/593] [MINOR][DOCS] Fix R documentation generation instruction for roxygen2 ## What changes were proposed in this pull request? This PR proposes to fix `roxygen2` to `5.0.1` in `docs/README.md` for SparkR documentation generation. If I use higher version and creates the doc, it shows the diff below. Not a big deal but it bothered me. ```diff diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f..159fca61e06 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION -57,6 +57,6 Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 5.0.1 +RoxygenNote: 6.0.1 VignetteBuilder: knitr NeedsCompilation: no ``` ## How was this patch tested? Manually tested. I met this every time I set the new environment for Spark dev but I have kept forgetting to fix it. Author: hyukjinkwon Closes #21020 from HyukjinKwon/minor-r-doc. --- docs/README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/README.md b/docs/README.md index 9eac4ba35c458..dbea4d64c4298 100644 --- a/docs/README.md +++ b/docs/README.md @@ -22,10 +22,13 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' ``` -(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) +Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. + +Note: Other versions of roxygen2 might work in SparkR documentation generation but `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 5.0.1, which is updated if the version is mismatched. ## Generating the Documentation HTML @@ -62,12 +65,12 @@ $ PRODUCTION=1 jekyll build ## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs) -You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory. +You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory. Similarly, you can build just the PySpark docs by running `make html` from the -`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and -the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh` +`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as +public in `__init__.py`. The SparkR docs can be built by running `$SPARK_HOME/R/create-docs.sh`, and +the SQL docs can be built by running `$SPARK_HOME/sql/create-docs.sh` after [building Spark](https://github.com/apache/spark#building-spark) first. When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various From c604d659e19c1b2704cdf8c8ea97edaf50d8cb6b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 11 Apr 2018 20:11:03 +0800 Subject: [PATCH 291/593] [SPARK-23951][SQL] Use actual java class instead of string representation. ## What changes were proposed in this pull request? This PR slightly refactors the newly added `ExprValue` API by quite a bit. The following changes are introduced: 1. `ExprValue` now uses the actual class instead of the class name as its type. This should give some more flexibility with generating code in the future. 2. Renamed `StatementValue` to `SimpleExprValue`. The statement concept is broader then an expression (untyped and it cannot be on the right hand side of an assignment), and this was not really what we were using it for. I have added a top level `JavaCode` trait that can be used in the future to reinstate (no pun intended) a statement a-like code fragment. 3. Added factory methods to the `JavaCode` companion object to make it slightly less verbose to create `JavaCode`/`ExprValue` objects. This is also what makes the diff quite large. 4. Added one more factory method to `ExprCode` to make it easier to create code-less expressions. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21026 from hvanhovell/SPARK-23951. --- .../sql/catalyst/expressions/Expression.scala | 10 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 35 +++- .../expressions/codegen/ExprValue.scala | 76 -------- .../codegen/GenerateMutableProjection.scala | 36 ++-- .../codegen/GenerateSafeProjection.scala | 25 +-- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/codegen/javaCode.scala | 166 ++++++++++++++++++ .../expressions/complexTypeCreator.scala | 2 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/literals.scala | 27 +-- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/nullExpressions.scala | 20 +-- .../expressions/objects/objects.scala | 7 +- .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/codegen/ExprValueSuite.scala | 14 +- .../sql/execution/ColumnarBatchScan.scala | 6 +- .../spark/sql/execution/ExpandExec.scala | 8 +- .../spark/sql/execution/GenerateExec.scala | 15 +- .../sql/execution/WholeStageCodegenExec.scala | 11 +- .../aggregate/HashAggregateExec.scala | 6 +- .../aggregate/HashMapGenerator.scala | 8 +- .../execution/basicPhysicalOperators.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 9 +- .../execution/joins/SortMergeJoinExec.scala | 12 +- 26 files changed, 315 insertions(+), 212 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7a5e49cb5206b..97dff6ae88299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(dataType)))) + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) + eval.isNull = JavaCode.isNullGlobal(globalIsNull) s"$globalIsNull = $localIsNull;" } else { "" @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = VariableValue(newValue, javaType) + eval.value = JavaCode.variable(newValue, dataType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index defd6f3cd8849..9212c3de1f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c9c60ef1be640..0abfc9fa4c465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { + def apply(isNull: ExprValue, value: ExprValue): ExprCode = { + ExprCode(code = "", isNull, value) + } + def forNullValue(dataType: DataType): ExprCode = { - val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = TrueLiteral, - value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) + ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { @@ -331,7 +333,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) + ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } def declareMutableStates(): String = { @@ -1004,8 +1006,9 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), - GlobalValue(value, javaType(expr.dataType))) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType)) subExprEliminationExprs ++= e.map(_ -> state).toMap } } @@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging { case _ => "Object" } + def javaClass(dt: DataType): Class[_] = dt match { + case BooleanType => java.lang.Boolean.TYPE + case ByteType => java.lang.Byte.TYPE + case ShortType => java.lang.Short.TYPE + case IntegerType | DateType => java.lang.Integer.TYPE + case LongType | TimestampType => java.lang.Long.TYPE + case FloatType => java.lang.Float.TYPE + case DoubleType => java.lang.Double.TYPE + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case udt: UserDefinedType[_] => javaClass(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[Object] + } + /** * Returns the boxed type in Java. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala deleted file mode 100644 index df5f1c58b1b2d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import scala.language.implicitConversions - -import org.apache.spark.sql.types.DataType - -// An abstraction that represents the evaluation result of [[ExprCode]]. -abstract class ExprValue { - - val javaType: String - - // Whether we can directly access the evaluation value anywhere. - // For example, a variable created outside a method can not be accessed inside the method. - // For such cases, we may need to pass the evaluation as parameter. - val canDirectAccess: Boolean - - def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) -} - -object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString -} - -// A literal evaluation of [[ExprCode]]. -class LiteralValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -object LiteralValue { - def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) - def unapply(literal: LiteralValue): Option[(String, String)] = - Some((literal.value, literal.javaType)) -} - -// A variable evaluation of [[ExprCode]]. -case class VariableValue( - val variableName: String, - val javaType: String) extends ExprValue { - override def toString: String = variableName - override val canDirectAccess: Boolean = false -} - -// A statement evaluation of [[ExprCode]]. -case class StatementValue( - val statement: String, - val javaType: String, - val canDirectAccess: Boolean = false) extends ExprValue { - override def toString: String = statement -} - -// A global variable evaluation of [[ExprCode]]. -case class GlobalValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -case object TrueLiteral extends LiteralValue("true", "boolean") -case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3ae0b54c754cf..33d14329ec95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() - val (validExpr, index) = expressions.zipWithIndex.filter { + val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true - }.unzip - val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + } + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { - case (ev, i) => - val e = expressions(i) - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") - if (e.nullable) { + val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { + case ((e, i), ev) => + val value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"), + e.dataType) + val (code, isNull) = if (e.nullable) { val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) + """.stripMargin, JavaCode.isNullGlobal(isNull)) } else { (s""" |${ev.code} |$value = ${ev.value}; - """.stripMargin, ev.isNull, value, i) + """.stripMargin, FalseLiteral) } + val update = CodeGenerator.updateColumn( + "mutableRow", + e.dataType, + i, + ExprCode(isNull, value), + e.nullable) + (code, update) } // Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(projectionCodes).map { - case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) - CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) - } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) - val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index a30a0b22cd305..01c350e9dbf69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt)), dt) + val converter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), + dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow])) } private def createCodeForArray( @@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), - CodeGenerator.javaType(elementType)), elementType) + val elementConverter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), + elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData])) } private def createCodeForMap( @@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData])) } @tailrec @@ -140,7 +143,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", FalseLiteral, input) + case _ => ExprCode(FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2fb441ac4500e..01b4d6c4529bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt))) + ExprCode( + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -109,7 +109,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == "false") { + if (input.isNull == FalseLiteral) { s""" |${input.code} |${writeField.trim} @@ -292,8 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro |$writeExpressions """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. - ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", - canDirectAccess = true)) + ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow])) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala new file mode 100644 index 0000000000000..74ff018488863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import java.lang.{Boolean => JBool} + +import scala.language.{existentials, implicitConversions} + +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * Trait representing an opaque fragments of java code. + */ +trait JavaCode { + def code: String + override def toString: String = code +} + +/** + * Utility functions for creating [[JavaCode]] fragments. + */ +object JavaCode { + /** + * Create a java literal. + */ + def literal(v: String, dataType: DataType): LiteralValue = dataType match { + case BooleanType if v == "true" => TrueLiteral + case BooleanType if v == "false" => FalseLiteral + case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a default literal. This is null for reference types, false for boolean types and + * -1 for other primitive types. + */ + def defaultLiteral(dataType: DataType): LiteralValue = { + new LiteralValue( + CodeGenerator.defaultValue(dataType, typedNull = true), + CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, dataType: DataType): VariableValue = { + variable(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, javaClass: Class[_]): VariableValue = { + VariableValue(name, javaClass) + } + + /** + * Create a local isNull variable. + */ + def isNullVariable(name: String): VariableValue = variable(name, BooleanType) + + /** + * Create a global java variable. + */ + def global(name: String, dataType: DataType): GlobalValue = { + global(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a global java variable. + */ + def global(name: String, javaClass: Class[_]): GlobalValue = { + GlobalValue(name, javaClass) + } + + /** + * Create a global isNull variable. + */ + def isNullGlobal(name: String): GlobalValue = global(name, BooleanType) + + /** + * Create an expression fragment. + */ + def expression(code: String, dataType: DataType): SimpleExprValue = { + expression(code, CodeGenerator.javaClass(dataType)) + } + + /** + * Create an expression fragment. + */ + def expression(code: String, javaClass: Class[_]): SimpleExprValue = { + SimpleExprValue(code, javaClass) + } + + /** + * Create a isNull expression fragment. + */ + def isNullExpression(code: String): SimpleExprValue = { + expression(code, BooleanType) + } +} + +/** + * A typed java fragment that must be a valid java expression. + */ +trait ExprValue extends JavaCode { + def javaType: Class[_] + def isPrimitive: Boolean = javaType.isPrimitive +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + + +/** + * A java expression fragment. + */ +case class SimpleExprValue(expr: String, javaType: Class[_]) extends ExprValue { + override def code: String = s"($expr)" +} + +/** + * A local variable java expression. + */ +case class VariableValue(variableName: String, javaType: Class[_]) extends ExprValue { + override def code: String = variableName +} + +/** + * A global variable java expression. + */ +case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { + override def code: String = value +} + +/** + * A literal java expression. + */ +class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { + override def code: String = value + + override def equals(arg: Any): Boolean = arg match { + case l: LiteralValue => l.javaType == javaType && l.value == value + case _ => false + } + + override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() +} + +case object TrueLiteral extends LiteralValue("true", JBool.TYPE) +case object FalseLiteral extends LiteralValue("false", JBool.TYPE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 49a8d12057188..67876a8565488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 409c0b6b79b81..205d77f6a9acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,8 +191,9 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), - CodeGenerator.javaType(dataType)) + ev.value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + dataType) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 49dd988b4b53c..32fdb13afbbfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,8 +813,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", - CodeGenerator.javaType(dataType))) + ExprCode.forNullValue(StringType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 742a650eb445d..246025b82d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -281,38 +281,41 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { if (value == null) { ExprCode.forNullValue(dataType) } else { + def toExprCode(code: String): ExprCode = { + ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) + } dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) + toExprCode(value.toString) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) + toExprCode("Float.NaN") case Float.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) + toExprCode("Float.POSITIVE_INFINITY") case Float.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) + toExprCode("Float.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) + toExprCode(s"${value}F") } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) + toExprCode("Double.NaN") case Double.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) + toExprCode("Double.POSITIVE_INFINITY") case Double.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) + toExprCode("Double.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) + toExprCode(s"${value}D") } case ByteType | ShortType => - ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) + ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) case TimestampType | LongType => - ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) + toExprCode(s"${value}L") case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) + ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7081a5e096d56..7eda65a867028 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -92,7 +92,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, - value = LiteralValue("null", CodeGenerator.javaType(dataType))) + value = JavaCode.defaultLiteral(dataType)) } override def sql: String = s"assert_true(${child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 55b6e346be82a..0787342bce6bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -321,12 +320,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } else { - VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } - ExprCode(code = eval.code, isNull = FalseLiteral, value = value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -352,12 +346,10 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull == TrueLiteral) { - FalseLiteral - } else if (eval.isNull == FalseLiteral) { - TrueLiteral - } else { - StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + val value = eval.isNull match { + case TrueLiteral => FalseLiteral + case FalseLiteral => TrueLiteral + case v => JavaCode.isNullExpression(s"!$v") } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index b2cca3178cd2a..50e90ca550807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -65,7 +65,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullGlobal(resultIsNull) } else { FalseLiteral } @@ -569,12 +569,11 @@ case class LambdaVariable( override def genCode(ctx: CodegenContext): ExprCode = { val isNullValue = if (nullable) { - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } else { FalseLiteral } - ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), - isNull = isNullValue) + ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8e83b35c3809c..f7c023111ff59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,8 +448,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) - val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), - VariableValue("dummy", "boolean")) + val dummy = SubExprEliminationState( + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType)) // raw testing of basic functionality { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index c8f4cff7db48d..378b8bc055e34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.BooleanType class ExprValueSuite extends SparkFunSuite { @@ -31,16 +32,7 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit.isPrimitive) assert(falseLit.isPrimitive) - trueLit match { - case LiteralValue(value, javaType) => - assert(value == "true" && javaType == "boolean") - case _ => fail() - } - - falseLit match { - case LiteralValue(value, javaType) => - assert(value == "false" && javaType == "boolean") - case _ => fail() - } + assert(trueLit === JavaCode.literal("true", BooleanType)) + assert(falseLit === JavaCode.literal("false", BooleanType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 434214a10e1e3..fc3dbc1c5591b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -52,7 +52,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(ctx.freshName("isNull")) } else { FalseLiteral } @@ -66,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) + ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 0d9a62cace62a..e4812f3d338fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,8 +157,10 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) + ExprCode( + code, + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, firstExpr.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 85c5ebfdaa689..f40c50df74ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types._ /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -170,10 +170,11 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), - VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode( + JavaCode.isNullExpression(s"$index == -1"), + JavaCode.variable(index, IntegerType))) } else { - Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode(FalseLiteral, JavaCode.variable(index, IntegerType))) } } else { Seq.empty @@ -316,11 +317,9 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, javaType)) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, - VariableValue(value, javaType)) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 805ff3cf001ba..828b51fa199de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) + ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow])) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -128,8 +128,8 @@ trait CodegenSupport extends SparkPlan { """.stripMargin.trim ExprCode(code, FalseLiteral, ev.value) } else { - // There is no columns - ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) + // There are no columns + ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow])) } } } @@ -246,11 +246,10 @@ trait CodegenSupport extends SparkPlan { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } - paramVars += ExprCode("", paramIsNull, - VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) + paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType)) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f7f10243d4cc..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,8 +194,10 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 4978954271311..de2d630de3fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ /** @@ -54,8 +54,10 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index cab7081400ce9..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) + val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index fa62a32d51f3e..6fa716d9fadee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, LongType} import org.apache.spark.util.TaskCompletionListener /** @@ -192,8 +192,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) } } } @@ -488,8 +487,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", FalseLiteral, - VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) + val resultVar = input ++ Seq(ExprCode.forNonNullValue( + JavaCode.variable(existsVar, BooleanType))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b61acb8d4fda9..d8261f0f33b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,11 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, -ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -531,13 +530,12 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), + leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, FalseLiteral, - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } }.unzip } From 271c891b91917d660d1f6b995de397c47c7a6058 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 11 Apr 2018 21:52:48 +0800 Subject: [PATCH 292/593] [SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient ## What changes were proposed in this pull request? Mark `HashAggregateExec.bufVars` as transient to avoid it from being serialized. Also manually null out this field at the end of `doProduceWithoutKeys()` to shorten its lifecycle, because it'll no longer be used after that. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21039 from rednaxelafx/codegen-improve. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..965950ed94fe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used for aggregation without keys. - private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. + @transient private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,6 +238,8 @@ case class HashAggregateExec( | } """.stripMargin) + bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner + val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 653fe02415a537299e15f92b56045569864b6183 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 09:49:25 -0500 Subject: [PATCH 293/593] [SPARK-6951][CORE] Speed up parsing of event logs during listing. This change introduces two optimizations to help speed up generation of listing data when parsing events logs. The first one allows the parser to be stopped when enough data to create the listing entry has been read. This is currently the start event plus environment info, to capture UI ACLs. If the end event is needed, the code will skip to the end of the log to try to find that information, instead of parsing the whole log file. Unfortunately this works better with uncompressed logs. Skipping bytes on compressed logs only saves the work of parsing lines and some events, so not a lot of gains are observed. The second optimization deals with in-progress logs. It works in two ways: first, it completely avoids parsing the rest of the log for these apps when enough listing data is read. This, unlike the above, also speeds things up for compressed logs, since only the very beginning of the log has to be read. On top of that, the code that decides whether to re-parse logs to get updated listing data will ignore in-progress applications until they've completed. Both optimizations can be disabled but are enabled by default. I tested this on some fake event logs to see the effect. I created 500 logs of about 60M each (so ~30G uncompressed; each log was 1.7M when compressed with zstd). Below, C = completed, IP = in-progress, the size means the amount of data re-parsed at the end of logs when necessary. ``` none/C none/IP zstd/C zstd/IP On / 16k 2s 2s 22s 2s On / 1m 3s 2s 24s 2s Off 1.1m 1.1m 26s 24s ``` This was with 4 threads on a single local SSD. As expected from the previous explanations, there are considerable gains for in-progress logs, and for uncompressed logs, but not so much when looking at the full compressed log. As a side note, I removed the custom code to get the scan time by creating a file on HDFS; since file mod times are not used to detect changed logs anymore, local time is enough for the current use of the SHS. Author: Marcelo Vanzin Closes #20952 from vanzin/SPARK-6951. --- .../deploy/history/FsHistoryProvider.scala | 251 ++++++++++++------ .../apache/spark/deploy/history/config.scala | 15 ++ .../spark/scheduler/ReplayListenerBus.scala | 11 + .../org/apache/spark/util/ListenerBus.scala | 5 +- .../history/FsHistoryProviderSuite.scala | 78 ++++-- 5 files changed, 264 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ace6d9e00c838..56db9359e033f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, ServiceLoader, UUID} +import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.io.Source import scala.util.Try import scala.xml.Node @@ -58,10 +59,10 @@ import org.apache.spark.util.kvstore._ * * == How new and updated attempts are detected == * - * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any - * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new attempt info entry - * and update or create a matching application info element in the list of applications. + * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any entries in the + * log dir whose size changed since the last scan time are considered new or updated. These are + * replayed to create a new attempt info entry and update or create a matching application info + * element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. * @@ -125,6 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) + private val fastInProgressParsing = conf.get(FAST_IN_PROGRESS_PARSING) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => @@ -402,13 +404,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { - val newLastScanTime = getNewLastScanTime() + val newLastScanTime = clock.getTimeMillis() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // FsHistoryProvider used to generate a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && @@ -417,15 +419,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .filter { entry => try { val info = listing.read(classOf[LogInfo], entry.getPath().toString()) - if (info.fileSize < entry.getLen()) { - // Log size has changed, it should be parsed. - true - } else { + + if (info.appId.isDefined) { // If the SHS view has a valid application, update the time the file was last seen so - // that the entry is not deleted from the SHS listing. - if (info.appId.isDefined) { - listing.write(info.copy(lastProcessed = newLastScanTime)) + // that the entry is not deleted from the SHS listing. Also update the file size, in + // case the code below decides we don't need to parse the log. + listing.write(info.copy(lastProcessed = newLastScanTime, fileSize = entry.getLen())) + } + + if (info.fileSize < entry.getLen()) { + if (info.appId.isDefined && fastInProgressParsing) { + // When fast in-progress parsing is on, we don't need to re-parse when the + // size changes, but we do need to invalidate any existing UIs. + invalidateUI(info.appId.get, info.attemptId) + false + } else { + true } + } else { false } } catch { @@ -449,7 +460,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val tasks = updated.map { entry => try { replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) }) } catch { // let the iteration over the updated entries break, since an exception on @@ -542,25 +553,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - private[history] def getNewLastScanTime(): Long = { - val fileName = "." + UUID.randomUUID().toString - val path = new Path(logDir, fileName) - val fos = fs.create(path) - - try { - fos.close() - fs.getFileStatus(path).getModificationTime - } catch { - case e: Exception => - logError("Exception encountered when attempting to update last scan time", e) - lastScanTime.get() - } finally { - if (!fs.delete(path, true)) { - logWarning(s"Error deleting ${path}") - } - } - } - override def writeEventLogs( appId: String, attemptId: Option[String], @@ -607,7 +599,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { + protected def mergeApplicationListing( + fileStatus: FileStatus, + scanTime: Long, + enableOptimizations: Boolean): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -616,32 +611,118 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() + val appCompleted = isCompleted(logPath.getName()) + val reparseChunkSize = conf.get(END_EVENT_REPARSE_CHUNK_SIZE) + + // Enable halt support in listener if: + // - app in progress && fast parsing enabled + // - skipping to end event is enabled (regardless of in-progress state) + val shouldHalt = enableOptimizations && + ((!appCompleted && fastInProgressParsing) || reparseChunkSize > 0) + val bus = new ReplayListenerBus() - val listener = new AppListingListener(fileStatus, clock) + val listener = new AppListingListener(fileStatus, clock, shouldHalt) bus.addListener(listener) - replay(fileStatus, bus, eventsFilter = eventsFilter) - - val (appId, attemptId) = listener.applicationInfo match { - case Some(app) => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + + logInfo(s"Parsing $logPath for listing data...") + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !appCompleted, eventsFilter) + } + + // If enabled above, the listing listener will halt parsing when there's enough information to + // create a listing entry. When the app is completed, or fast parsing is disabled, we still need + // to replay until the end of the log file to try to find the app end event. Instead of reading + // and parsing line by line, this code skips bytes from the underlying stream so that it is + // positioned somewhere close to the end of the log file. + // + // Because the application end event is written while some Spark subsystems such as the + // scheduler are still active, there is no guarantee that the end event will be the last + // in the log. So, to be safe, the code uses a configurable chunk to be re-parsed at + // the end of the file, and retries parsing the whole log later if the needed data is + // still not found. + // + // Note that skipping bytes in compressed files is still not cheap, but there are still some + // minor gains over the normal log parsing done by the replay bus. + // + // This code re-opens the file so that it knows where it's skipping to. This isn't as cheap as + // just skipping from the current position, but there isn't a a good way to detect what the + // current position is, since the replay listener bus buffers data internally. + val lookForEndEvent = shouldHalt && (appCompleted || !fastInProgressParsing) + if (lookForEndEvent && listener.applicationInfo.isDefined) { + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + val target = fileStatus.getLen() - reparseChunkSize + if (target > 0) { + logInfo(s"Looking for end event; skipping $target bytes from $logPath...") + var skipped = 0L + while (skipped < target) { + skipped += in.skip(target - skipped) } } + val source = Source.fromInputStream(in).getLines() + + // Because skipping may leave the stream in the middle of a line, read the next line + // before replaying. + if (target > 0) { + source.next() + } + + bus.replay(source, logPath.toString, !appCompleted, eventsFilter) + } + } + + logInfo(s"Finished parsing $logPath") + + listener.applicationInfo match { + case Some(app) if !lookForEndEvent || app.attempts.head.info.completed => + // In this case, we either didn't care about the end event, or we found it. So the + // listing data is good. + invalidateUI(app.info.id, app.attempts.head.info.attemptId) addListing(app) - (Some(app.info.id), app.attempts.head.info.attemptId) + listing.write(LogInfo(logPath.toString(), scanTime, Some(app.info.id), + app.attempts.head.info.attemptId, fileStatus.getLen())) + + // For a finished log, remove the corresponding "in progress" entry from the listing DB if + // the file is really gone. + if (appCompleted) { + val inProgressLog = logPath.toString() + EventLoggingListener.IN_PROGRESS + try { + // Fetch the entry first to avoid an RPC when it's already removed. + listing.read(classOf[LogInfo], inProgressLog) + if (!fs.isFile(new Path(inProgressLog))) { + listing.delete(classOf[LogInfo], inProgressLog) + } + } catch { + case _: NoSuchElementException => + } + } + + case Some(_) => + // In this case, the attempt is still not marked as finished but was expected to. This can + // mean the end event is before the configured threshold, so call the method again to + // re-parse the whole log. + logInfo(s"Reparsing $logPath since end event was not found.") + mergeApplicationListing(fileStatus, scanTime, false) case _ => // If the app hasn't written down its app ID to the logs, still record the entry in the // listing db, with an empty ID. This will make the log eligible for deletion if the app // does not make progress after the configured max log age. - (None, None) + listing.write(LogInfo(logPath.toString(), scanTime, None, None, fileStatus.getLen())) + } + } + + /** + * Invalidate an existing UI for a given app attempt. See LoadedAppUI for a discussion on the + * UI lifecycle. + */ + private def invalidateUI(appId: String, attemptId: Option[String]): Unit = { + synchronized { + activeUIs.get((appId, attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** @@ -696,29 +777,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. - * `ReplayEventsFilter` determines what events are replayed. - */ - private def replay( - eventLog: FileStatus, - bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { - val logPath = eventLog.getPath() - val isCompleted = !logPath.getName().endsWith(EventLoggingListener.IN_PROGRESS) - logInfo(s"Replaying log path: $logPath") - // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, - // and when we read the file here. That is OK -- it may result in an unnecessary refresh - // when there is no update, but will not result in missing an update. We *must* prevent - // an error the other way -- if we report a size bigger (ie later) than the file that is - // actually read, we may never refresh the app. FileStatus is guaranteed to be static - // after it's created, so we get a file size that is no bigger than what is actually read. - Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => - bus.replay(in, logPath.toString, !isCompleted, eventsFilter) - logInfo(s"Finished parsing $logPath") - } - } - /** * Rebuilds the application state store from its event log. */ @@ -741,8 +799,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } replayBus.addListener(listener) try { - replay(eventLog, replayBus) + val path = eventLog.getPath() + logInfo(s"Parsing $path to re-build UI...") + Utils.tryWithResource(EventLoggingListener.openEventLog(path, fs)) { in => + replayBus.replay(in, path.toString(), maybeTruncated = !isCompleted(path.toString())) + } trackingStore.close(false) + logInfo(s"Finished parsing $path") } catch { case e: Exception => Utils.tryLogNonFatalError { @@ -881,6 +944,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + private def isCompleted(name: String): Boolean = { + !name.endsWith(EventLoggingListener.IN_PROGRESS) + } + } private[history] object FsHistoryProvider { @@ -945,11 +1012,17 @@ private[history] class ApplicationInfoWrapper( } -private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { +private[history] class AppListingListener( + log: FileStatus, + clock: Clock, + haltEnabled: Boolean) extends SparkListener { private val app = new MutableApplicationInfo() private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + private var gotEnvUpdate = false + private var halted = false + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { app.id = event.appId.orNull app.name = event.appName @@ -958,6 +1031,8 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.startTime = new Date(event.time) attempt.lastUpdated = new Date(clock.getTimeMillis()) attempt.sparkUser = event.sparkUser + + checkProgress() } override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { @@ -968,11 +1043,18 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { - val allProperties = event.environmentDetails("Spark Properties").toMap - attempt.viewAcls = allProperties.get("spark.ui.view.acls") - attempt.adminAcls = allProperties.get("spark.admin.acls") - attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + // Only parse the first env update, since any future changes don't have any effect on + // the ACLs set for the UI. + if (!gotEnvUpdate) { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + + gotEnvUpdate = true + checkProgress() + } } override def onOtherEvent(event: SparkListenerEvent): Unit = event match { @@ -989,6 +1071,17 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } } + /** + * Throws a halt exception to stop replay if enough data to create the app listing has been + * read. + */ + private def checkProgress(): Unit = { + if (haltEnabled && !halted && app.id != null && gotEnvUpdate) { + halted = true + throw new HaltReplayException() + } + } + private class MutableApplicationInfo { var id: String = null var name: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index efdbf672bb52f..25ba9edb9e014 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -49,4 +49,19 @@ private[spark] object config { .intConf .createWithDefault(18080) + val FAST_IN_PROGRESS_PARSING = + ConfigBuilder("spark.history.fs.inProgressOptimization.enabled") + .doc("Enable optimized handling of in-progress logs. This option may leave finished " + + "applications that fail to rename their event logs listed as in-progress.") + .booleanConf + .createWithDefault(true) + + val END_EVENT_REPARSE_CHUNK_SIZE = + ConfigBuilder("spark.history.fs.endEventReparseChunkSize") + .doc("How many bytes to parse at the end of log files looking for the end event. " + + "This is used to speed up generation of application listings by skipping unnecessary " + + "parts of event log files. It can be disabled by setting this config to 0.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index c9cd662f5709d..226c23733c870 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -115,6 +115,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } } catch { + case e: HaltReplayException => + // Just stop replay. case _: EOFException if maybeTruncated => case ioe: IOException => throw ioe @@ -124,8 +126,17 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + override protected def isIgnorableException(e: Throwable): Boolean = { + e.isInstanceOf[HaltReplayException] + } + } +/** + * Exception that can be thrown by listeners to halt replay. This is handled by ReplayListenerBus + * only, and will cause errors if thrown when using other bus implementations. + */ +private[spark] class HaltReplayException extends RuntimeException private[spark] object ReplayListenerBus { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 76a56298aaebc..b25a731401f23 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -81,7 +81,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { try { doPostEvent(listener, event) } catch { - case NonFatal(e) => + case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { if (maybeTimerContext != null) { @@ -97,6 +97,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ protected def doPostEvent(listener: L, event: E): Unit + /** Allows bus implementations to prevent error logging for certain exceptions. */ + protected def isIgnorableException(e: Throwable): Boolean = false + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 0ba57bf4563c1..77b239489d489 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{doReturn, mock, spy, verify} +import org.mockito.Mockito.{mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -151,8 +151,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc var mergeApplicationListingCall = 0 override protected def mergeApplicationListing( fileStatus: FileStatus, - lastSeen: Long): Unit = { - super.mergeApplicationListing(fileStatus, lastSeen) + lastSeen: Long, + enableSkipToEnd: Boolean): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen, enableSkipToEnd) mergeApplicationListingCall += 1 } } @@ -256,14 +257,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => - list should not be (null) list.size should be (1) list.head.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt1, true, None, + writeFile(app2Attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -649,8 +649,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Add more info to the app log, and trigger the provider to update things. writeFile(appLog, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + SparkListenerJobStart(0, 1L, Nil, null) ) provider.checkForLogs() @@ -668,11 +667,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("clean up stale app information") { val storeDir = Utils.createTempDir() val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - val provider = spy(new FsHistoryProvider(conf)) + val clock = new ManualClock() + val provider = spy(new FsHistoryProvider(conf, clock)) val appId = "new1" // Write logs for two app attempts. - doReturn(1L).when(provider).getNewLastScanTime() + clock.advance(1) val attempt1 = newLogFile(appId, Some("1"), inProgress = false) writeFile(attempt1, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), @@ -697,7 +697,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since // attempt 2 still exists, listing data should be there. - doReturn(2L).when(provider).getNewLastScanTime() + clock.advance(1) attempt1.delete() updateAndCheck(provider) { list => assert(list.size === 1) @@ -708,7 +708,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(provider.getAppUI(appId, None) === None) // Delete the second attempt's log file. Now everything should go away. - doReturn(3L).when(provider).getNewLastScanTime() + clock.advance(1) attempt2.delete() updateAndCheck(provider) { list => assert(list.isEmpty) @@ -718,9 +718,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-21571: clean up removes invalid history files") { val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") - val provider = new FsHistoryProvider(conf, clock) { - override def getNewLastScanTime(): Long = clock.getTimeMillis() - } + val provider = new FsHistoryProvider(conf, clock) // Create 0-byte size inprogress and complete files var logCount = 0 @@ -772,6 +770,54 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(new File(testDir.toURI).listFiles().size === validLogCount) } + test("always find end event for finished apps") { + // Create a log file where the end event is before the configure chunk to be reparsed at + // the end of the file. The correct listing should still be generated. + val log = newLogFile("end-event-test", None, inProgress = false) + writeFile(log, true, None, + Seq( + SparkListenerApplicationStart("end-event-test", Some("end-event-test"), 1L, "test", None), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> Seq.empty, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(5L) + ) ++ (1 to 1000).map { i => SparkListenerJobStart(i, i, Nil) }: _*) + + val conf = createTestConf().set(END_EVENT_REPARSE_CHUNK_SIZE.key, s"1k") + val provider = new FsHistoryProvider(conf) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).attempts.size === 1) + assert(list(0).attempts(0).completed) + } + } + + test("parse event logs with optimizations off") { + val conf = createTestConf() + .set(END_EVENT_REPARSE_CHUNK_SIZE, 0L) + .set(FAST_IN_PROGRESS_PARSING, false) + val provider = new FsHistoryProvider(conf) + + val complete = newLogFile("complete", None, inProgress = false) + writeFile(complete, true, None, + SparkListenerApplicationStart("complete", Some("complete"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + val incomplete = newLogFile("incomplete", None, inProgress = true) + writeFile(incomplete, true, None, + SparkListenerApplicationStart("incomplete", Some("incomplete"), 1L, "test", None) + ) + + updateAndCheck(provider) { list => + list.size should be (2) + list.count(_.attempts.head.completed) should be (1) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -815,7 +861,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private def createTestConf(inMemory: Boolean = false): SparkConf = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set(EVENT_LOG_DIR, testDir.getAbsolutePath()) + .set(FAST_IN_PROGRESS_PARSING, true) if (!inMemory) { conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) @@ -848,4 +895,3 @@ class TestGroupsMappingProvider extends GroupMappingServiceProvider { mappings.get(username).map(Set(_)).getOrElse(Set.empty) } } - From 3cb82047f2f51af553df09b9323796af507d36f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 10:13:44 -0500 Subject: [PATCH 294/593] [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. The current in-process launcher implementation just calls the SparkSubmit object, which, in case of errors, will more often than not exit the JVM. This is not desirable since this launcher is meant to be used inside other applications, and that would kill the application. The change turns SparkSubmit into a class, and abstracts aways some of the functionality used to print error messages and abort the submission process. The default implementation uses the logging system for messages, and throws exceptions for errors. As part of that I also moved some code that doesn't really belong in SparkSubmit to a better location. The command line invocation of spark-submit now uses a special implementation of the SparkSubmit class that overrides those behaviors to do what is expected from the command line version (print to the terminal, exit the JVM, etc). A lot of the changes are to replace calls to methods such as "printErrorAndExit" with the new API. As part of adding tests for this, I had to fix some small things in the launcher option parser so that things like "--version" can work when used in the launcher library. There is still code that prints directly to the terminal, like all the Ivy-related code in SparkSubmitUtils, and other areas where some re-factoring would help, like the CommandLineUtils class, but I chose to leave those alone to keep this change more focused. Aside from existing and added unit tests, I ran command line tools with a bunch of different arguments to make sure messages and errors behave like before. Author: Marcelo Vanzin Closes #20925 from vanzin/SPARK-22941. --- .../apache/spark/deploy/DependencyUtils.scala | 30 +- .../org/apache/spark/deploy/SparkSubmit.scala | 318 +++++++++--------- .../spark/deploy/SparkSubmitArguments.scala | 90 +++-- .../spark/deploy/worker/DriverWrapper.scala | 4 +- .../apache/spark/util/CommandLineUtils.scala | 18 +- .../spark/launcher/SparkLauncherSuite.java | 37 +- .../spark/deploy/SparkSubmitSuite.scala | 69 ++-- .../rest/StandaloneRestSubmitSuite.scala | 2 +- .../spark/launcher/AbstractLauncher.java | 6 +- .../spark/launcher/InProcessLauncher.java | 14 +- .../launcher/SparkSubmitCommandBuilder.java | 82 +++-- project/MimaExcludes.scala | 7 +- .../deploy/mesos/MesosClusterDispatcher.scala | 10 +- .../MesosClusterDispatcherArguments.scala | 6 +- 14 files changed, 401 insertions(+), 292 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index fac834a70b893..178bdcfccb603 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.{MutableURLClassLoader, Utils} -private[deploy] object DependencyUtils { +private[deploy] object DependencyUtils extends Logging { def resolveMavenDependencies( packagesExclusions: String, @@ -75,7 +76,7 @@ private[deploy] object DependencyUtils { def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { if (jars != null) { for (jar <- jars.split(",")) { - SparkSubmit.addJarToClasspath(jar, loader) + addJarToClasspath(jar, loader) } } } @@ -151,6 +152,31 @@ private[deploy] object DependencyUtils { }.mkString(",") } + def addJarToClasspath(localJar: String, loader: MutableURLClassLoader): Unit = { + val uri = Utils.resolveURI(localJar) + uri.getScheme match { + case "file" | "local" => + val file = new File(uri.getPath) + if (file.exists()) { + loader.addURL(file.toURI.toURL) + } else { + logWarning(s"Local jar $file does not exist, skipping.") + } + case _ => + logWarning(s"Skip remote jar $uri.") + } + } + + /** + * Merge a sequence of comma-separated file lists, some of which may be null to indicate + * no files, into a single comma-separated string. + */ + def mergeFileLists(lists: String*): String = { + val merged = lists.filterNot(StringUtils.isBlank) + .flatMap(Utils.stringToSeq) + if (merged.nonEmpty) merged.mkString(",") else null + } + private def splitOnFragment(path: String): (URI, Option[String]) = { val uri = Utils.resolveURI(path) val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index eddbedeb1024d..427c797755b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ */ private[deploy] object SparkSubmitAction extends Enumeration { type SparkSubmitAction = Value - val SUBMIT, KILL, REQUEST_STATUS = Value + val SUBMIT, KILL, REQUEST_STATUS, PRINT_VERSION = Value } /** @@ -67,78 +67,32 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit extends CommandLineUtils with Logging { +private[spark] class SparkSubmit extends Logging { import DependencyUtils._ + import SparkSubmit._ - // Cluster managers - private val YARN = 1 - private val STANDALONE = 2 - private val MESOS = 4 - private val LOCAL = 8 - private val KUBERNETES = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES - - // Deploy modes - private val CLIENT = 1 - private val CLUSTER = 2 - private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - - // Special primary resource names that represent shells rather than application jars. - private val SPARK_SHELL = "spark-shell" - private val PYSPARK_SHELL = "pyspark-shell" - private val SPARKR_SHELL = "sparkr-shell" - private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" - private val R_PACKAGE_ARCHIVE = "rpkg.zip" - - private val CLASS_NOT_FOUND_EXIT_STATUS = 101 - - // Following constants are visible for testing. - private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.yarn.YarnClusterApplication" - private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() - private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() - private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" - - // scalastyle:off println - private[spark] def printVersionAndExit(): Unit = { - printStream.println("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - printStream.println("Using Scala %s, %s, %s".format( - Properties.versionString, Properties.javaVmName, Properties.javaVersion)) - printStream.println("Branch %s".format(SPARK_BRANCH)) - printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) - printStream.println("Revision %s".format(SPARK_REVISION)) - printStream.println("Url %s".format(SPARK_REPO_URL)) - printStream.println("Type --help for more information.") - exitFn(0) - } - // scalastyle:on println - - override def main(args: Array[String]): Unit = { + def doSubmit(args: Array[String]): Unit = { // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to // be reset before the application starts. val uninitLog = initializeLogIfNecessary(true, silent = true) - val appArgs = new SparkSubmitArguments(args) + val appArgs = parseArguments(args) if (appArgs.verbose) { - // scalastyle:off println - printStream.println(appArgs) - // scalastyle:on println + logInfo(appArgs.toString) } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.PRINT_VERSION => printVersion() } } + protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) + } + /** * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ @@ -156,6 +110,24 @@ object SparkSubmit extends CommandLineUtils with Logging { .requestSubmissionStatus(args.submissionToRequestStatusFor) } + /** Print version information to the log. */ + private def printVersion(): Unit = { + logInfo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + logInfo("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) + logInfo(s"Branch $SPARK_BRANCH") + logInfo(s"Compiled by user $SPARK_BUILD_USER on $SPARK_BUILD_DATE") + logInfo(s"Revision $SPARK_REVISION") + logInfo(s"Url $SPARK_REPO_URL") + logInfo("Type --help for more information.") + } + /** * Submit the application using the provided parameters. * @@ -185,10 +157,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { - // scalastyle:off println - printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - // scalastyle:on println - exitFn(1) + error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") } else { throw e } @@ -210,14 +179,11 @@ object SparkSubmit extends CommandLineUtils with Logging { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { - // scalastyle:off println - printStream.println("Running Spark using the REST application submission protocol.") - // scalastyle:on println - doRunMain() + logInfo("Running Spark using the REST application submission protocol.") } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => - printWarning(s"Master endpoint ${args.master} was not a REST server. " + + logWarning(s"Master endpoint ${args.master} was not a REST server. " + "Falling back to legacy submission gateway instead.") args.useRest = false submit(args, false) @@ -245,19 +211,6 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { - try { - doPrepareSubmitEnvironment(args, conf) - } catch { - case e: SparkException => - printErrorAndExit(e.getMessage) - throw e - } - } - - private def doPrepareSubmitEnvironment( - args: SparkSubmitArguments, - conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() @@ -268,7 +221,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val clusterManager: Int = args.master match { case "yarn" => YARN case "yarn-client" | "yarn-cluster" => - printWarning(s"Master ${args.master} is deprecated since 2.0." + + logWarning(s"Master ${args.master} is deprecated since 2.0." + " Please use master \"yarn\" with specified deploy mode instead.") YARN case m if m.startsWith("spark") => STANDALONE @@ -276,7 +229,7 @@ object SparkSubmit extends CommandLineUtils with Logging { case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") + error("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -284,7 +237,9 @@ object SparkSubmit extends CommandLineUtils with Logging { var deployMode: Int = args.deployMode match { case "client" | null => CLIENT case "cluster" => CLUSTER - case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 + case _ => + error("Deploy mode must be either client or cluster") + -1 } // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both @@ -296,16 +251,16 @@ object SparkSubmit extends CommandLineUtils with Logging { deployMode = CLUSTER args.master = "yarn" case ("yarn-cluster", "client") => - printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + error("Client deploy mode is not compatible with master \"yarn-cluster\"") case ("yarn-client", "cluster") => - printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + error("Cluster deploy mode is not compatible with master \"yarn-client\"") case (_, mode) => args.master = "yarn" } // Make sure YARN is included in our build if we're trying to use it if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load YARN classes. " + "This copy of Spark may not have been compiled with YARN support.") } @@ -315,7 +270,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.master = Utils.checkAndGetK8sMasterUrl(args.master) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load KUBERNETES classes. " + "This copy of Spark may not have been compiled with KUBERNETES support.") } @@ -324,23 +279,23 @@ object SparkSubmit extends CommandLineUtils with Logging { // Fail fast, the following modes are not supported or applicable (clusterManager, deployMode) match { case (STANDALONE, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + error("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") case (STANDALONE, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + + error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") case (KUBERNETES, _) if args.isPython => - printErrorAndExit("Python applications are currently not supported for Kubernetes.") + error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => - printErrorAndExit("R applications are currently not supported for Kubernetes.") + error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => - printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") + error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + error("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + error("Cluster deploy mode is not applicable to Spark SQL shell.") case (_, CLUSTER) if isThriftServer(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") + error("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } @@ -493,11 +448,11 @@ object SparkSubmit extends CommandLineUtils with Logging { if (args.isR && clusterManager == YARN) { val sparkRPackagePath = RUtils.localSparkRPackagePath if (sparkRPackagePath.isEmpty) { - printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + error("SPARK_HOME does not exist for R application in YARN mode.") } val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) if (!sparkRPackageFile.exists()) { - printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + error(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString @@ -510,7 +465,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val rPackageFile = RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) if (!rPackageFile.exists()) { - printErrorAndExit("Failed to zip all the built R packages.") + error("Failed to zip all the built R packages.") } val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString @@ -521,12 +476,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // TODO: Support distributing R packages with standalone cluster if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + error("Distributing R packages with standalone cluster is not supported.") } // TODO: Support distributing R packages with mesos cluster if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with mesos cluster is not supported.") + error("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -799,9 +754,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" - // scalastyle:off println - printStream.println(s"Setting ${key} to ${shortUserName}") - // scalastyle:off println + logInfo(s"Setting ${key} to ${shortUserName}") sparkConf.set(key, shortUserName) } @@ -817,16 +770,14 @@ object SparkSubmit extends CommandLineUtils with Logging { sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { - // scalastyle:off println if (verbose) { - printStream.println(s"Main class:\n$childMainClass") - printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") + logInfo(s"Main class:\n$childMainClass") + logInfo(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") - printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") - printStream.println("\n") + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") + logInfo("\n") } - // scalastyle:on println val loader = if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { @@ -848,23 +799,19 @@ object SparkSubmit extends CommandLineUtils with Logging { mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass.", e) if (childMainClass.contains("thriftserver")) { - // scalastyle:off println - printStream.println(s"Failed to load main class $childMainClass.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load main class $childMainClass.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) case e: NoClassDefFoundError => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass: ${e.getMessage()}") if (e.getMessage.contains("org/apache/hadoop/hive")) { - // scalastyle:off println - printStream.println(s"Failed to load hive class.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load hive class.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { @@ -872,7 +819,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + logWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") } new JavaMainApplication(mainClass) } @@ -891,29 +838,90 @@ object SparkSubmit extends CommandLineUtils with Logging { app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => - findCause(t) match { - case SparkUserAppException(exitCode) => - System.exit(exitCode) - - case t: Throwable => - throw t - } + throw findCause(t) } } - private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { - val uri = Utils.resolveURI(localJar) - uri.getScheme match { - case "file" | "local" => - val file = new File(uri.getPath) - if (file.exists()) { - loader.addURL(file.toURI.toURL) - } else { - printWarning(s"Local jar $file does not exist, skipping.") + /** Throw a SparkException with the given error message. */ + private def error(msg: String): Unit = throw new SparkException(msg) + +} + + +/** + * This entry point is used by the launcher library to start in-process Spark applications. + */ +private[spark] object InProcessSparkSubmit { + + def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() + submit.doSubmit(args) + } + +} + +object SparkSubmit extends CommandLineUtils with Logging { + + // Cluster managers + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES + + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + + // Special primary resource names that represent shells rather than application jars. + private val SPARK_SHELL = "spark-shell" + private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" + + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + + // Following constants are visible for testing. + private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" + + override def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() { + self => + + override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) { + override protected def logInfo(msg: => String): Unit = self.logInfo(msg) + + override protected def logWarning(msg: => String): Unit = self.logWarning(msg) } - case _ => - printWarning(s"Skip remote jar $uri.") + } + + override protected def logInfo(msg: => String): Unit = printMessage(msg) + + override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + + override def doSubmit(args: Array[String]): Unit = { + try { + super.doSubmit(args) + } catch { + case e: SparkUserAppException => + exitFn(e.exitCode) + case e: SparkException => + printErrorAndExit(e.getMessage()) + } + } + } + + submit.doSubmit(args) } /** @@ -962,17 +970,6 @@ object SparkSubmit extends CommandLineUtils with Logging { res == SparkLauncher.NO_RESOURCE } - /** - * Merge a sequence of comma-separated file lists, some of which may be null to indicate - * no files, into a single comma-separated string. - */ - private[deploy] def mergeFileLists(lists: String*): String = { - val merged = lists.filterNot(StringUtils.isBlank) - .flatMap(_.split(",")) - .mkString(",") - if (merged == "") null else merged - } - } /** Provides utility functions to be used inside SparkSubmit. */ @@ -1000,12 +997,12 @@ private[spark] object SparkSubmitUtils { override def toString: String = s"$groupId:$artifactId:$version" } -/** - * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided - * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. - * @param coordinates Comma-delimited string of maven coordinates - * @return Sequence of Maven coordinates - */ + /** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { coordinates.split(",").map { p => val splits = p.replace("/", ":").split(":") @@ -1304,6 +1301,13 @@ private[spark] object SparkSubmitUtils { rule } + def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => throw new SparkException(s"Spark config without '=': $pair") + } + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 8e7070593687b..0733fdb72cafb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -29,7 +29,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source import scala.util.Try +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -40,7 +42,7 @@ import org.apache.spark.util.Utils * The env argument is used for testing. */ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) - extends SparkSubmitArgumentsParser { + extends SparkSubmitArgumentsParser with Logging { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -85,8 +87,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() - // scalastyle:off println - if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") + if (verbose) { + logInfo(s"Using properties file: $propertiesFile") + } Option(propertiesFile).foreach { filename => val properties = Utils.getPropertiesFromFile(filename) properties.foreach { case (k, v) => @@ -95,21 +98,16 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Property files may contain sensitive information, so redact before printing if (verbose) { Utils.redact(properties).foreach { case (k, v) => - SparkSubmit.printStream.println(s"Adding default property: $k=$v") + logInfo(s"Adding default property: $k=$v") } } } - // scalastyle:on println defaultProperties } // Set parameters from command line arguments - try { - parse(args.asJava) - } catch { - case e: IllegalArgumentException => - SparkSubmit.printErrorAndExit(e.getMessage()) - } + parse(args.asJava) + // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() // Remove keys that don't start with "spark." from `sparkProperties`. @@ -141,7 +139,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S sparkProperties.foreach { case (k, v) => if (!k.startsWith("spark.")) { sparkProperties -= k - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + logWarning(s"Ignoring non-spark config property: $k=$v") } } } @@ -215,10 +213,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } } catch { case _: Exception => - SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") + error(s"Cannot load main class from JAR $primaryResource") } case _ => - SparkSubmit.printErrorAndExit( + error( s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " + "Please specify a class through --class.") } @@ -248,6 +246,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case SUBMIT => validateSubmitArguments() case KILL => validateKillArguments() case REQUEST_STATUS => validateStatusRequestArguments() + case PRINT_VERSION => } } @@ -256,62 +255,61 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") + error("Must specify a primary resource (JAR or Python or R file)") } if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { - SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") + error("No main class set in JAR; please specify one with --class") } if (driverMemory != null && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + error("Driver memory must be a positive number") } if (executorMemory != null && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + error("Executor memory must be a positive number") } if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + error("Executor cores must be a positive number") } if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + error("Total executor cores must be a positive number") } if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { - SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") + error("--py-files given but primary resource is not a Python script") } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { - throw new Exception(s"When running with master '$master' " + + error(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } if (proxyUser != null && principal != null) { - SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + error("Only one of --proxy-user or --principal can be provided.") } } private def validateKillArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( - "Killing submissions is only supported in standalone or Mesos mode!") + error("Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + error("Please specify a submission to kill.") } } private def validateStatusRequestArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( + error( "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + error("Please specify a submission to request status for.") } } @@ -368,7 +366,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case DEPLOY_MODE => if (value != "client" && value != "cluster") { - SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") + error("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value @@ -405,14 +403,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case KILL_SUBMISSION => submissionToKill = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + error(s"Action cannot be both $action and $KILL.") } action = KILL case STATUS => submissionToRequestStatusFor = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + error(s"Action cannot be both $action and $REQUEST_STATUS.") } action = REQUEST_STATUS @@ -444,7 +442,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + val (confName, confValue) = SparkSubmitUtils.parseSparkConfProperty(value) sparkProperties(confName) = confValue case PROXY_USER => @@ -463,15 +461,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S verbose = true case VERSION => - SparkSubmit.printVersionAndExit() + action = SparkSubmitAction.PRINT_VERSION case USAGE_ERROR => printUsageAndExit(1) case _ => - throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + error(s"Unexpected argument '$opt'.") } - true + action != SparkSubmitAction.PRINT_VERSION } /** @@ -482,7 +480,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S */ override protected def handleUnknown(opt: String): Boolean = { if (opt.startsWith("-")) { - SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") + error(s"Unrecognized option '$opt'.") } primaryResource = @@ -501,20 +499,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { - // scalastyle:off println - val outStream = SparkSubmit.printStream if (unknownParam != null) { - outStream.println("Unknown/unsupported param " + unknownParam) + logInfo("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) - outStream.println(command) + logInfo(command) val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB - outStream.println( + logInfo( s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, @@ -596,12 +592,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S ) if (SparkSubmit.isSqlShell(mainClass)) { - outStream.println("CLI options:") - outStream.println(getSqlShellOptions()) + logInfo("CLI options:") + logInfo(getSqlShellOptions()) } - // scalastyle:on println - SparkSubmit.exitFn(exitCode) + throw new SparkUserAppException(exitCode) } /** @@ -655,4 +650,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } + + private def error(msg: String): Unit = throw new SparkException(msg) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 3f71237164a15..8d6a2b80ef5f2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -25,7 +25,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -93,7 +93,7 @@ object DriverWrapper extends Logging { val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { - SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + DependencyUtils.mergeFileLists(jarsProp, resolvedMavenCoordinates) } else { jarsProp } diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala index d73901686b705..4b6602b50aa1c 100644 --- a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -33,24 +33,14 @@ private[spark] trait CommandLineUtils { private[spark] var printStream: PrintStream = System.err // scalastyle:off println - - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printMessage(str: String): Unit = printStream.println(str) + // scalastyle:on println private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") + printMessage("Error: " + str) + printMessage("Run with --help for usage help or --verbose for debug output") exitFn(1) } - // scalastyle:on println - - private[spark] def parseSparkConfProperty(pair: String): (String, String) = { - pair.split("=", 2).toSeq match { - case Seq(k, v) => (k, v) - case _ => printErrorAndExit(s"Spark config without '=': $pair") - throw new SparkException(s"Spark config without '=': $pair") - } - } - def main(args: Array[String]): Unit } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 2225591a4ff75..6a1a38c1a54f4 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -109,7 +109,7 @@ public void testChildProcLauncher() throws Exception { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.appender=childproc") + "-Dfoo=bar -Dtest.appender=console") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) @@ -192,6 +192,41 @@ private void inProcessLauncherTestImpl() throws Exception { } } + @Test + public void testInProcessLauncherDoesNotKillJvm() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + List wrongArgs = Arrays.asList( + new String[] { "--unknown" }, + new String[] { opts.DEPLOY_MODE, "invalid" }); + + for (String[] args : wrongArgs) { + InProcessLauncher launcher = new InProcessLauncher() + .setAppResource(SparkLauncher.NO_RESOURCE); + switch (args.length) { + case 2: + launcher.addSparkArg(args[0], args[1]); + break; + + case 1: + launcher.addSparkArg(args[0]); + break; + + default: + fail("FIXME: invalid test."); + } + + SparkAppHandle handle = launcher.startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.FAILED, handle.getState()); + } + + // Run --version, which is useless as a use case, but should succeed and not exit the JVM. + // The expected state is "LOST" since "--version" doesn't report state back to the handle. + SparkAppHandle handle = new InProcessLauncher().addSparkArg(opts.VERSION).startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.LOST, handle.getState()); + } + public static class SparkLauncherTestApp { public static void main(String[] args) throws Exception { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 0d7c342a5eacd..7451e07b25a1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -109,6 +110,8 @@ class SparkSubmitSuite private val emptyIvySettings = File.createTempFile("ivy", ".xml") FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + private val submit = new SparkSubmit() + override def beforeEach() { super.beforeEach() } @@ -128,13 +131,16 @@ class SparkSubmitSuite } test("handle binary specified but not class") { - testPrematureExit(Array("foo.jar"), "No main class") + val jar = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + testPrematureExit(Array(jar.toString()), "No main class") } test("handles arguments with --key=val") { val clArgs = Seq( "--jars=one.jar,two.jar,three.jar", - "--name=myApp") + "--name=myApp", + "--class=org.FooBar", + SparkLauncher.NO_RESOURCE) val appArgs = new SparkSubmitArguments(clArgs) appArgs.jars should include regex (".*one.jar,.*two.jar,.*three.jar") appArgs.name should be ("myApp") @@ -182,7 +188,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") conf.get("spark.submit.deployMode") should be ("client") @@ -192,11 +198,11 @@ class SparkSubmitSuite "--master", "yarn", "--deploy-mode", "cluster", "--conf", "spark.submit.deployMode=client", - "-class", "org.SomeClass", + "--class", "org.SomeClass", "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") conf1.get("spark.submit.deployMode") should be ("cluster") @@ -210,7 +216,7 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") conf2.get("spark.submit.deployMode") should be ("client") } @@ -233,7 +239,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -276,7 +282,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -322,7 +328,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -359,7 +365,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -381,7 +387,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -403,7 +409,7 @@ class SparkSubmitSuite "/home/thejar.jar", "arg1") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar")) @@ -428,7 +434,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.executor.memory") should be ("5g") conf.get("spark.master") should be ("yarn") conf.get("spark.submit.deployMode") should be ("cluster") @@ -441,12 +447,12 @@ class SparkSubmitSuite val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } @@ -625,7 +631,7 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -640,7 +646,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -656,7 +662,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -708,7 +714,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) conf.get("spark.files") should be(Utils.resolveURIs(files)) @@ -725,7 +731,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -740,7 +746,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -757,7 +763,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -778,17 +784,17 @@ class SparkSubmitSuite } test("SPARK_CONF_DIR overrides spark-defaults.conf") { - forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + forConfDir(Map("spark.executor.memory" -> "3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", unusedJar.toString) - val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + val appArgs = new SparkSubmitArguments(args, env = Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("3g") } } @@ -809,6 +815,9 @@ class SparkSubmitSuite val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + val tempPyFile = File.createTempFile("tmpApp", ".py") + tempPyFile.deleteOnExit() + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", @@ -818,10 +827,10 @@ class SparkSubmitSuite "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", - jar2.toString) + tempPyFile.toURI().toString()) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) conf.get("spark.yarn.dist.files").split(",").toSet should be @@ -947,7 +956,7 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. conf.get("spark.yarn.dist.jars") should be (tmpJarPath) @@ -1007,7 +1016,7 @@ class SparkSubmitSuite ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) val jars = conf.get("spark.yarn.dist.jars").split(",").toSet @@ -1058,7 +1067,7 @@ class SparkSubmitSuite "hello") val exception = intercept[SparkException] { - SparkSubmit.main(args) + submit.doSubmit(args) } assert(exception.getMessage() === "hello") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index e505bc018857d..54c168a8218f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,7 +445,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 44e69fc45dffa..4e02843480e8f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -139,7 +139,7 @@ public T setMainClass(String mainClass) { public T addSparkArg(String arg) { SparkSubmitOptionParser validator = new ArgumentValidator(false); validator.parse(Arrays.asList(arg)); - builder.sparkArgs.add(arg); + builder.userArgs.add(arg); return self(); } @@ -187,8 +187,8 @@ public T addSparkArg(String name, String value) { } } else { validator.parse(Arrays.asList(name, value)); - builder.sparkArgs.add(name); - builder.sparkArgs.add(value); + builder.userArgs.add(name); + builder.userArgs.add(value); } return self(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java index 6d726b4a69a86..688e1f763c205 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java @@ -89,10 +89,18 @@ Method findSparkSubmit() throws IOException { } Class sparkSubmit; + // SPARK-22941: first try the new SparkSubmit interface that has better error handling, + // but fall back to the old interface in case someone is mixing & matching launcher and + // Spark versions. try { - sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); - } catch (Exception e) { - throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", e); + sparkSubmit = cl.loadClass("org.apache.spark.deploy.InProcessSparkSubmit"); + } catch (Exception e1) { + try { + sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); + } catch (Exception e2) { + throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", + e2); + } } Method main; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index e0ef22d7d5058..5cb6457bf5c21 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -88,8 +88,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkLauncher.NO_RESOURCE); } - final List sparkArgs; - private final boolean isAppResourceReq; + final List userArgs; + private final List parsedArgs; + private final boolean requiresAppResource; private final boolean isExample; /** @@ -99,17 +100,27 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ private boolean allowsMixedArguments; + /** + * This constructor is used when creating a user-configurable launcher. It allows the + * spark-submit argument list to be modified after creation. + */ SparkSubmitCommandBuilder() { - this.sparkArgs = new ArrayList<>(); - this.isAppResourceReq = true; + this.requiresAppResource = true; this.isExample = false; + this.parsedArgs = new ArrayList<>(); + this.userArgs = new ArrayList<>(); } + /** + * This constructor is used when invoking spark-submit; it parses and validates arguments + * provided by the user on the command line. + */ SparkSubmitCommandBuilder(List args) { this.allowsMixedArguments = false; - this.sparkArgs = new ArrayList<>(); + this.parsedArgs = new ArrayList<>(); boolean isExample = false; List submitArgs = args; + this.userArgs = Collections.emptyList(); if (args.size() > 0) { switch (args.get(0)) { @@ -131,21 +142,21 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } this.isExample = isExample; - OptionParser parser = new OptionParser(); + OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.isAppResourceReq = parser.isAppResourceReq; - } else { + this.requiresAppResource = parser.requiresAppResource; + } else { this.isExample = isExample; - this.isAppResourceReq = false; + this.requiresAppResource = false; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) { + if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) { + } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -154,9 +165,19 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); - SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + OptionParser parser = new OptionParser(false); + final boolean requiresAppResource; + + // If the user args array is not empty, we need to parse it to detect exactly what + // the user is trying to run, so that checks below are correct. + if (!userArgs.isEmpty()) { + parser.parse(userArgs); + requiresAppResource = parser.requiresAppResource; + } else { + requiresAppResource = this.requiresAppResource; + } - if (!allowsMixedArguments && isAppResourceReq) { + if (!allowsMixedArguments && requiresAppResource) { checkArgument(appResource != null, "Missing application resource."); } @@ -208,15 +229,16 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isAppResourceReq) { - checkArgument(!isExample || mainClass != null, "Missing example class name."); + if (isExample) { + checkArgument(mainClass != null, "Missing example class name."); } + if (mainClass != null) { args.add(parser.CLASS); args.add(mainClass); } - args.addAll(sparkArgs); + args.addAll(parsedArgs); if (appResource != null) { args.add(appResource); } @@ -399,7 +421,12 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean isAppResourceReq = true; + boolean requiresAppResource = true; + private final boolean errorOnUnknownArgs; + + OptionParser(boolean errorOnUnknownArgs) { + this.errorOnUnknownArgs = errorOnUnknownArgs; + } @Override protected boolean handle(String opt, String value) { @@ -443,23 +470,23 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - isAppResourceReq = false; - sparkArgs.add(opt); - sparkArgs.add(value); + requiresAppResource = false; + parsedArgs.add(opt); + parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; case VERSION: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; default: - sparkArgs.add(opt); + parsedArgs.add(opt); if (value != null) { - sparkArgs.add(value); + parsedArgs.add(value); } break; } @@ -483,12 +510,13 @@ protected boolean handleUnknown(String opt) { mainClass = className; appResource = SparkLauncher.NO_RESOURCE; return false; - } else { + } else if (errorOnUnknownArgs) { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); checkState(appResource == null, "Found unrecognized argument but resource is already set."); appResource = opt; return false; } + return true; } @Override diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b37b4d51775e8..a87fa68422c34 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,12 +36,17 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), - + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index aa378c9d340f1..ccf33e8d4283c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer @@ -100,7 +100,13 @@ private[mesos] object MesosClusterDispatcher Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf - val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + val dispatcherArgs = try { + new MesosClusterDispatcherArguments(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage()) + null + } conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 096bb4e1af688..267a4283db9e6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkSubmitUtils import org.apache.spark.util.{IntParam, Utils} private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { @@ -95,9 +96,8 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: parse(tail) case ("--conf") :: value :: tail => - val pair = MesosClusterDispatcher. - parseSparkConfProperty(value) - confProperties(pair._1) = pair._2 + val (k, v) = SparkSubmitUtils.parseSparkConfProperty(value) + confProperties(k) = v parse(tail) case ("--help") :: tail => From 75a183071c4ed2e407c930edfdf721779662b3ee Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 11 Apr 2018 09:59:38 -0700 Subject: [PATCH 295/593] [SPARK-22883] ML test for StructuredStreaming: spark.ml.feature, I-M ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * IDF * Imputer * Interaction * MaxAbsScaler * MinHashLSH * MinMaxScaler * NGram ## How was this patch tested? It is a bunch of tests! Author: Joseph K. Bradley Closes #20964 from jkbradley/SPARK-22883-part2. --- .../apache/spark/ml/feature/IDFSuite.scala | 14 +++--- .../spark/ml/feature/ImputerSuite.scala | 31 ++++++++++--- .../spark/ml/feature/InteractionSuite.scala | 46 ++++++++++--------- .../spark/ml/feature/MaxAbsScalerSuite.scala | 14 +++--- .../spark/ml/feature/MinHashLSHSuite.scala | 25 ++++++++-- .../spark/ml/feature/MinMaxScalerSuite.scala | 14 +++--- .../apache/spark/ml/feature/NGramSuite.scala | 2 +- 7 files changed, 89 insertions(+), 57 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 005edf73d29be..cdd62be43b54c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class IDFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,7 +55,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((numOfData + 1.0) / (x + 1.0)) }) @@ -72,7 +70,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead MLTestingUtils.checkCopyAndUids(idfEst, idfModel) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } @@ -85,7 +83,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 }) @@ -99,7 +97,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setMinDocFreq(1) .fit(df) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index c08b35b419266..75f63a623e6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -16,13 +16,12 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.SparkException +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { val df = spark.createDataFrame( Seq( @@ -76,6 +75,28 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default ImputerSuite.iterateStrategyTest(imputer, df) } + test("Imputer should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + val df = Seq[(java.lang.Double, Double)]( + (4.0, 4.0), + (10.0, 10.0), + (10.0, 10.0), + (Double.NaN, 8.0), + (null, 8.0) + ).toDF("value", "expected_mean_value") + val imputer = new Imputer() + .setInputCols(Array("value")) + .setOutputCols(Array("out")) + .setStrategy("mean") + val model = imputer.fit(df) + testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") { + case Row(exp: java.lang.Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + test("Imputer throws exception when surrogate cannot be computed") { val df = spark.createDataFrame( Seq( (0, Double.NaN, 1.0, 1.0), @@ -164,8 +185,6 @@ object ImputerSuite { * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" */ def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { - val inputCols = imputer.getInputCols - Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 54f059e5f143e..eea31fc7ae3f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class InteractionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -63,9 +63,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("numeric interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -73,14 +73,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -92,9 +93,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("nominal interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -103,14 +104,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def val df = data.select( col("a").as( "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index 918da4f9388d4..8dd0f0cb91e37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -14,15 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -45,9 +44,10 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setOutputCol("scaled") val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(expectedVec: Vector, actualVec: Vector) => + assert(expectedVec === actualVec, + s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec") } MLTestingUtils.checkCopyAndUids(scaler, model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3da0fb7da01ae..1c2956cb82908 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{Dataset, Row} -class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class MinHashLSHSuite extends MLTest with DefaultReadWriteTest { @transient var dataset: Dataset[_] = _ @@ -175,4 +174,20 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(precision == 1.0) assert(recall >= 0.7) } + + test("MinHashLSHModel.transform should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + model.set(model.inputCol, "keys") + testTransformer[Tuple1[Vector]](dataset.toDF(), model, "keys", model.getOutputCol) { + case Row(_: Vector, output: Seq[_]) => + assert(output.length === model.randCoefficients.length) + // no AND-amplification yet: SPARK-18450, so each hash output is of length 1 + output.foreach { + case hashOutput: Vector => assert(hashOutput.size === 1) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 51db74eb739ca..2d965f2ca2c54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -48,9 +46,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setMax(5) val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 === vector2, "Transformed vector is different with expected.") } MLTestingUtils.checkCopyAndUids(scaler, model) @@ -114,7 +112,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De val model = scaler.fit(df) model.transform(df).select("expected", "scaled").collect() .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + assert(vector1 === vector2, "Transformed vector is different with expected.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index e5956ee9942aa..201a335e0d7be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -84,7 +84,7 @@ class NGramSuite extends MLTest with DefaultReadWriteTest { def testNGram(t: NGram, dataFrame: DataFrame): Unit = { testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { - case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => + case Row(actualNGrams : Seq[_], wantedNGrams: Seq[_]) => assert(actualNGrams === wantedNGrams) } } From 9d960de0814a1128318676cc2e91f447cdf0137f Mon Sep 17 00:00:00 2001 From: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Date: Wed, 11 Apr 2018 15:52:13 -0700 Subject: [PATCH 296/593] typo rawPredicition changed to rawPrediction MultilayerPerceptronClassifier had 4 occurrences ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Closes #21030 from JBauerKogentix/patch-1. --- python/pyspark/ml/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fbbe3d0307c81..ec17653a1adf9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1543,12 +1543,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction") + rawPredictionCol="rawPrediction") """ super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -1562,12 +1562,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier. """ kwargs = self._input_kwargs From e904dfaf0d16f9fa0cc4d2f46a3dec1b1d77de75 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 11 Apr 2018 17:04:34 -0700 Subject: [PATCH 297/593] Revert "[SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient" This reverts commit 271c891b91917d660d1f6b995de397c47c7a6058. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 965950ed94fe8..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. - @transient private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used for aggregation without keys. + private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,8 +238,6 @@ case class HashAggregateExec( | } """.stripMargin) - bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner - val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 6a2289ecf020a99cd9b3bcea7da5e78fb4e0303a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 12 Apr 2018 15:58:04 +0800 Subject: [PATCH 298/593] [SPARK-23962][SQL][TEST] Fix race in currentExecutionIds(). SQLMetricsTestUtils.currentExecutionIds() was racing with the listener bus, which lead to some flaky tests. We should wait till the listener bus is empty. I tested by adding some Thread.sleep()s in SQLAppStatusListener, which reproduced the exceptions I saw on Jenkins. With this change, they went away. Author: Imran Rashid Closes #21041 from squito/SPARK-23962. --- .../apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 534d8bb629b8c..dcc540fc4f109 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -34,6 +34,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ protected def currentExecutionIds(): Set[Long] = { + spark.sparkContext.listenerBus.waitUntilEmpty(10000) statusStore.executionsList.map(_.executionId).toSet } From 0b19122d434e39eb117ccc3174a0688c9c874d48 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 22:21:30 +0800 Subject: [PATCH 299/593] [SPARK-23762][SQL] UTF8StringBuffer uses MemoryBlock ## What changes were proposed in this pull request? This PR tries to use `MemoryBlock` in `UTF8StringBuffer`. In general, there are two advantages to use `MemoryBlock`. 1. Has clean API calls rather than using a Java array or `PlatformMemory` 2. Improve runtime performance of memory access instead of using `Object`. ## How was this patch tested? Added `UTF8StringBufferSuite` Author: Kazuaki Ishizaki Closes #20871 from kiszk/SPARK-23762. --- .../codegen/UTF8StringBuilder.java | 35 +++++++--------- .../codegen/UTF8StringBuilderSuite.scala | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java index f0f66bae245fd..f8000d78cd1b6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -19,6 +19,8 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -29,43 +31,34 @@ public class UTF8StringBuilder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - private byte[] buffer; - private int cursor = Platform.BYTE_ARRAY_OFFSET; + private ByteArrayMemoryBlock buffer; + private int length = 0; public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new byte[16]; + this.buffer = new ByteArrayMemoryBlock(16); } // Grows the buffer by at least `neededSize` private void grow(int neededSize) { - if (neededSize > ARRAY_MAX - totalSize()) { + if (neededSize > ARRAY_MAX - length) { throw new UnsupportedOperationException( "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } - final int length = totalSize() + neededSize; - if (buffer.length < length) { - int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; - Platform.copyMemory( - buffer, - Platform.BYTE_ARRAY_OFFSET, - tmp, - Platform.BYTE_ARRAY_OFFSET, - totalSize()); + final int requestedSize = length + neededSize; + if (buffer.size() < requestedSize) { + int newLength = requestedSize < ARRAY_MAX / 2 ? requestedSize * 2 : ARRAY_MAX; + final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); + MemoryBlock.copyMemory(buffer, tmp, length); buffer = tmp; } } - private int totalSize() { - return cursor - Platform.BYTE_ARRAY_OFFSET; - } - public void append(UTF8String value) { grow(value.numBytes()); - value.writeToMemory(buffer, cursor); - cursor += value.numBytes(); + value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); + length += value.numBytes(); } public void append(String value) { @@ -73,6 +66,6 @@ public void append(String value) { } public UTF8String build() { - return UTF8String.fromBytes(buffer, 0, totalSize()); + return UTF8String.fromBytes(buffer.getByteArray(), 0, length); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala new file mode 100644 index 0000000000000..1b25a4b191f86 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class UTF8StringBuilderSuite extends SparkFunSuite { + + test("basic test") { + val sb = new UTF8StringBuilder() + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("") + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("abcd") + assert(sb.build() === UTF8String.fromString("abcd")) + + sb.append(UTF8String.fromString("1234")) + assert(sb.build() === UTF8String.fromString("abcd1234")) + + // expect to grow an internal buffer + sb.append(UTF8String.fromString("efgijk567890")) + assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) + } +} From 0f93b91a71444a1a938acfd8ea2191c54fb0187c Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 12 Apr 2018 15:47:42 -0600 Subject: [PATCH 300/593] [SPARK-23751][FOLLOW-UP] fix build for scala-2.12 ## What changes were proposed in this pull request? fix build for scala-2.12 ## How was this patch tested? Manual. Author: WeichenXu Closes #21051 from WeichenXu123/fix_build212. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index af8ff64d33ffe..adf8145726711 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -85,7 +85,7 @@ object KolmogorovSmirnovTest { dataset: Dataset[_], sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble) } /** From 682002b6da844ed11324ee5ff4d00fc0294c0b31 Mon Sep 17 00:00:00 2001 From: Patrick Pisciuneri Date: Fri, 13 Apr 2018 09:45:27 +0800 Subject: [PATCH 301/593] [SPARK-23867][SCHEDULER] use droppedCount in logWarning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Get the count of dropped events for output in log message. ## How was this patch tested? The fix is pretty trivial, but `./dev/run-tests` were run and were successful. Please review http://spark.apache.org/contributing.html before opening a pull request. vanzin cloud-fan The contribution is my original work and I license the work to the project under the project’s open source license. Author: Patrick Pisciuneri Closes #20977 from phpisciuneri/fix-log-warning. --- .../main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index 7e14938acd8e0..c1fedd63f6a90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -166,7 +166,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi val prevLastReportTimestamp = lastReportTimestamp lastReportTimestamp = System.currentTimeMillis() val previous = new java.util.Date(prevLastReportTimestamp) - logWarning(s"Dropped $droppedEvents events from $name since $previous.") + logWarning(s"Dropped $droppedCount events from $name since $previous.") } } } From 14291b061b9b40eadbf4ed442f9a5021b8e09597 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 12 Apr 2018 20:00:25 -0700 Subject: [PATCH 302/593] [SPARK-23748][SS] Fix SS continuous process doesn't support SubqueryAlias issue ## What changes were proposed in this pull request? Current SS continuous doesn't support processing on temp table or `df.as("xxx")`, SS will throw an exception as LogicalPlan not supported, details described in [here](https://issues.apache.org/jira/browse/SPARK-23748). So here propose to add this support. ## How was this patch tested? new UT. Author: jerryshao Closes #21017 from jerryshao/SPARK-23748. --- .../UnsupportedOperationChecker.scala | 2 +- .../continuous/ContinuousSuite.scala | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index b55043c270644..ff9d6d7a7dded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,7 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index f5884b9c8de12..ef74efef156d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -171,6 +171,25 @@ class ContinuousSuite extends ContinuousSuiteBase { "Continuous processing does not support current time operations.")) } + test("subquery alias") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .createOrReplaceTempView("rate") + val test = spark.sql("select value from rate where value > 5") + + testStream(test, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + test("repeatedly restart") { val df = spark.readStream .format("rate") From ab7b961a4fe96ca02b8352d16b0fa80c972b67fc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 13 Apr 2018 11:28:13 +0800 Subject: [PATCH 303/593] [SPARK-23942][PYTHON][SQL] Makes collect in PySpark as action for a query executor listener ## What changes were proposed in this pull request? This PR proposes to add `collect` to a query executor as an action. Seems `collect` / `collect` with Arrow are not recognised via `QueryExecutionListener` as an action. For example, if we have a custom listener as below: ```scala package org.apache.spark.sql import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener class TestQueryExecutionListener extends QueryExecutionListener with Logging { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { logError("Look at me! I'm 'onSuccess'") } override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } } ``` and set `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener` Other operations in PySpark or Scala side seems fine: ```python >>> sql("SELECT * FROM range(1)").show() ``` ``` 18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' +---+ | id| +---+ | 0| +---+ ``` ```scala scala> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' res1: Array[org.apache.spark.sql.Row] = Array([0]) ``` but .. **Before** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` id 0 0 ``` **After** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` 18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' id 0 0 ``` ## How was this patch tested? I have manually tested as described above and unit test was added. Author: hyukjinkwon Closes #21007 from HyukjinKwon/SPARK-23942. --- python/pyspark/sql/tests.py | 87 ++++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 20 +++-- .../sql/TestQueryExecutionListener.scala | 44 ++++++++++ 3 files changed, 134 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 96c2a776a5049..4e99c8e3c6b10 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -186,16 +186,12 @@ def __init__(self, key, value): self.value = value -class ReusedSQLTestCase(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ @contextmanager def sql_conf(self, pairs): @@ -204,6 +200,7 @@ def sql_conf(self, pairs): `value` to the configuration `key` and then restores it back when it exits. """ assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." keys = pairs.keys() new_values = pairs.values() @@ -219,6 +216,18 @@ def sql_conf(self, pairs): else: self.spark.conf.set(key, old_value) + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3066,6 +3075,64 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + class SparkSessionTests(PySparkTestCase): # This test is separate because it's closely related with session's start and stop. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0aee1d7be5788..917168162b236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3189,10 +3189,10 @@ class Dataset[T] private[sql]( private[sql] def collectToPython(): Int = { EvaluatePython.registerPicklers() - withNewExecutionId { + withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) - val iter = new SerDeUtil.AutoBatchedPickler( - queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") } } @@ -3201,8 +3201,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + withAction("collectAsArrowToPython", queryExecution) { plan => + val iter: Iterator[Array[Byte]] = + toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -3311,14 +3312,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { + private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - queryExecution.toRdd.mapPartitionsInternal { iter => + plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } + + // This is only used in tests, for now. + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + toArrowPayload(queryExecution.executedPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala new file mode 100644 index 0000000000000..d2a6358ee822b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + + +class TestQueryExecutionListener extends QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + OnSuccessCall.isOnSuccessCalled.set(true) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } +} + +/** + * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for + * the test case in PySpark. See SPARK-23942. + */ +object OnSuccessCall { + val isOnSuccessCalled = new AtomicBoolean(false) + + def isCalled(): Boolean = isOnSuccessCalled.get() + + def clear(): Unit = isOnSuccessCalled.set(false) +} From 1018be44d6c52cf18e14d84160850063f0e60a1d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 12 Apr 2018 22:30:59 -0700 Subject: [PATCH 304/593] [SPARK-23971] Should not leak Spark sessions across test suites ## What changes were proposed in this pull request? Many suites currently leak Spark sessions (sometimes with stopped SparkContexts) via the thread-local active Spark session and default Spark session. We should attempt to clean these up and detect when this happens to improve the reproducibility of tests. ## How was this patch tested? Existing tests Author: Eric Liang Closes #21058 from ericl/clear-session. --- .../org/apache/spark/SharedSparkSession.java | 9 ++++++-- .../org/apache/spark/sql/SparkSession.scala | 23 +++++++++++++++++-- .../apache/spark/sql/SessionStateSuite.scala | 2 ++ .../spark/sql/test/SharedSparkSession.scala | 22 ++++++++++++++---- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java index 43779878890db..35a250955b282 100644 --- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -42,7 +42,12 @@ public void setUp() throws IOException { @After public void tearDown() { - spark.stop(); - spark = null; + try { + spark.stop(); + spark = null; + } finally { + SparkSession.clearDefaultSession(); + SparkSession.clearActiveSession(); + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b107492fbb330..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} /** @@ -81,6 +81,9 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => + // The call site where this SparkSession was constructed. + private val creationSite: CallSite = Utils.getCallSite() + private[sql] def this(sc: SparkContext) { this(sc, None, None, new SparkSessionExtensions) } @@ -763,7 +766,7 @@ class SparkSession private( @InterfaceStability.Stable -object SparkSession { +object SparkSession extends Logging { /** * Builder for [[SparkSession]]. @@ -1090,4 +1093,20 @@ object SparkSession { } } + private[spark] def cleanupAnyExistingSession(): Unit = { + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + logWarning( + s"""An existing Spark session exists as the active or default session. + |This probably means another suite leaked it. Attempting to stop it before continuing. + |This existing Spark session was created at: + | + |${session.get.creationSite.longForm} + | + """.stripMargin) + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 4efae4c46c2e1..7d1366092d1e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -44,6 +44,8 @@ class SessionStateSuite extends SparkFunSuite { if (activeSession != null) { activeSession.stop() activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e758c865b908f..8968dbf36d507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -60,6 +60,7 @@ trait SharedSparkSession protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() new TestSparkSession(sparkConf) } @@ -92,11 +93,22 @@ trait SharedSparkSession * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } } } From 4b07036799b01894826b47c73142fe282c607a57 Mon Sep 17 00:00:00 2001 From: Fangshi Li Date: Fri, 13 Apr 2018 13:46:34 +0800 Subject: [PATCH 305/593] [SPARK-23815][CORE] Spark writer dynamic partition overwrite mode may fail to write output on multi level partition ## What changes were proposed in this pull request? Spark introduced new writer mode to overwrite only related partitions in SPARK-20236. While we are using this feature in our production cluster, we found a bug when writing multi-level partitions on HDFS. A simple test case to reproduce this issue: val df = Seq(("1","2","3")).toDF("col1", "col2","col3") df.write.partitionBy("col1","col2").mode("overwrite").save("/my/hdfs/location") If HDFS location "/my/hdfs/location" does not exist, there will be no output. This seems to be caused by the job commit change in SPARK-20236 in HadoopMapReduceCommitProtocol. In the commit job process, the output has been written into staging dir /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2, and then the code calls fs.rename to rename /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2 to /my/hdfs/location/col1=1/col2=2. However, in our case the operation will fail on HDFS because /my/hdfs/location/col1=1 does not exists. HDFS rename can not create directory for more than one level. This does not happen in the new unit test added with SPARK-20236 which uses local file system. We are proposing a fix. When cleaning current partition dir /my/hdfs/location/col1=1/col2=2 before the rename op, if the delete op fails (because /my/hdfs/location/col1=1/col2=2 may not exist), we call mkdirs op to create the parent dir /my/hdfs/location/col1=1 (if the parent dir does not exist) so the following rename op can succeed. Reference: in official HDFS document(https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html), the rename command has precondition "dest must be root, or have a parent that exists" ## How was this patch tested? We have tested this patch on our production cluster and it fixed the problem Author: Fangshi Li Closes #20931 from fangshil/master. --- .../internal/io/HadoopMapReduceCommitProtocol.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6d20ef1f98a3c..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -186,7 +186,17 @@ class HadoopMapReduceCommitProtocol( logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") for (part <- partitionPaths) { val finalPartPath = new Path(path, part) - fs.delete(finalPartPath, true) + if (!fs.delete(finalPartPath, true) && !fs.exists(finalPartPath.getParent)) { + // According to the official hadoop FileSystem API spec, delete op should assume + // the destination is no longer present regardless of return value, thus we do not + // need to double check if finalPartPath exists before rename. + // Also in our case, based on the spec, delete returns false only when finalPartPath + // does not exist. When this happens, we need to take action if parent of finalPartPath + // also does not exist(e.g. the scenario described on SPARK-23815), because + // FileSystem API spec on rename op says the rename dest(finalPartPath) must have + // a parent that exists, otherwise we may get unexpected result on the rename. + fs.mkdirs(finalPartPath.getParent) + } fs.rename(new Path(stagingDir, part), finalPartPath) } } From 0323e61465ee747c9a57a70e9d6108876499546e Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 13 Apr 2018 00:00:04 -0700 Subject: [PATCH 306/593] [SPARK-23905][SQL] Add UDF weekday ## What changes were proposed in this pull request? Add UDF weekday ## How was this patch tested? A new test Author: yucai Closes #21009 from yucai/SPARK-23905. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 55 +++++++++++++++---- .../expressions/DateExpressionsSuite.scala | 11 ++++ .../resources/sql-tests/inputs/datetime.sql | 2 + .../sql-tests/results/datetime.sql.out | 9 ++- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..131b958239e41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -395,6 +395,7 @@ object FunctionRegistry { expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), + expression[WeekDay]("weekday"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 32fdb13afbbfa..b9b2cd5bdb9f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -426,36 +426,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa """, since = "2.3.0") // scalastyle:on line.size.limit -case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfWeek(child: Expression) extends DayWeek { - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType + override protected def nullSafeEval(date: Any): Any = { + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + cal.get(Calendar.DAY_OF_WEEK) + } - @transient private lazy val c = { - Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).", + examples = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 3 + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class WeekDay(child: Expression) extends DayWeek { override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.DAY_OF_WEEK) + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + (cal.get(Calendar.DAY_OF_WEEK) + 5 ) % 7 } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" + val c = "calWeekDay" ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ }) } } +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient protected lazy val cal: Calendar = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 786266a2c13c0..080ec487cfa6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -211,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) } + test("WeekDay") { + checkEvaluation(WeekDay(Literal.create(null, DateType)), null) + checkEvaluation(WeekDay(Literal(d)), 2) + checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index adea2bfa82cd3..547c2bef02b24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -25,3 +25,5 @@ create temporary view ttf2 as select * from values select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; select a, b from ttf2 order by a, current_date; + +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index bbb6851e69c7e..4e1cfa6e48c1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -81,3 +81,10 @@ struct -- !query 8 output 1 2 2 3 + +-- !query 9 +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query 3 schema +struct +-- !query 3 output +5 3 5 NULL 4 From a83ae0d9bc1b8f4909b9338370efe4020079bea7 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 13 Apr 2018 08:43:58 -0700 Subject: [PATCH 307/593] [SPARK-22839][K8S] Refactor to unify driver and executor pod builder APIs ## What changes were proposed in this pull request? Breaks down the construction of driver pods and executor pods in a way that uses a common abstraction for both spark-submit creating the driver and KubernetesClusterSchedulerBackend creating the executor. Encourages more code reuse and is more legible than the older approach. The high-level design is discussed in more detail on the JIRA ticket. This pull request is the implementation of that design with some minor changes in the implementation details. No user-facing behavior should break as a result of this change. ## How was this patch tested? Migrated all unit tests from the old submission steps architecture to the new architecture. Integration tests should not have to change and pass given that this shouldn't change any outward behavior. Author: mcheah Closes #20910 from mccheah/spark-22839-incremental. --- .../org/apache/spark/deploy/k8s/Config.scala | 2 +- .../spark/deploy/k8s/KubernetesConf.scala | 184 +++++++++++++ .../deploy/k8s/KubernetesDriverSpec.scala} | 25 +- .../spark/deploy/k8s/KubernetesUtils.scala | 11 - .../deploy/k8s/MountSecretsBootstrap.scala | 72 ----- ...ConfigurationStep.scala => SparkPod.scala} | 24 +- .../k8s/features/BasicDriverFeatureStep.scala | 136 ++++++++++ .../features/BasicExecutorFeatureStep.scala | 179 +++++++++++++ ...iverKubernetesCredentialsFeatureStep.scala | 216 +++++++++++++++ .../features/DriverServiceFeatureStep.scala | 97 +++++++ .../KubernetesFeatureConfigStep.scala | 71 +++++ .../features/MountSecretsFeatureStep.scala | 62 +++++ .../k8s/submit/DriverConfigOrchestrator.scala | 145 ----------- .../submit/KubernetesClientApplication.scala | 80 +++--- .../k8s/submit/KubernetesDriverBuilder.scala | 56 ++++ .../k8s/submit/KubernetesDriverSpec.scala | 47 ---- .../steps/BasicDriverConfigurationStep.scala | 163 ------------ .../steps/DependencyResolutionStep.scala | 61 ----- .../DriverKubernetesCredentialsStep.scala | 245 ------------------ .../submit/steps/DriverMountSecretsStep.scala | 38 --- .../steps/DriverServiceBootstrapStep.scala | 104 -------- .../cluster/k8s/ExecutorPodFactory.scala | 227 ---------------- .../k8s/KubernetesClusterManager.scala | 12 +- .../KubernetesClusterSchedulerBackend.scala | 20 +- .../k8s/KubernetesExecutorBuilder.scala | 41 +++ .../deploy/k8s/KubernetesConfSuite.scala | 175 +++++++++++++ .../BasicDriverFeatureStepSuite.scala | 153 +++++++++++ .../BasicExecutorFeatureStepSuite.scala | 179 +++++++++++++ ...bernetesCredentialsFeatureStepSuite.scala} | 101 +++++--- .../DriverServiceFeatureStepSuite.scala | 227 ++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 61 +++++ .../MountSecretsFeatureStepSuite.scala} | 29 ++- .../spark/deploy/k8s/submit/ClientSuite.scala | 216 +++++++-------- .../DriverConfigOrchestratorSuite.scala | 131 ---------- .../submit/KubernetesDriverBuilderSuite.scala | 102 ++++++++ .../BasicDriverConfigurationStepSuite.scala | 122 --------- .../steps/DependencyResolutionStepSuite.scala | 69 ----- .../DriverServiceBootstrapStepSuite.scala | 180 ------------- .../cluster/k8s/ExecutorPodFactorySuite.scala | 195 -------------- ...bernetesClusterSchedulerBackendSuite.scala | 37 ++- .../k8s/KubernetesExecutorBuilderSuite.scala | 75 ++++++ 41 files changed, 2289 insertions(+), 2081 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala rename resource-managers/kubernetes/core/src/{test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala => main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala} (57%) delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverConfigurationStep.scala => SparkPod.scala} (64%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverKubernetesCredentialsStepSuite.scala => features/DriverKubernetesCredentialsFeatureStepSuite.scala} (67%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverMountSecretsStepSuite.scala => features/MountSecretsFeatureStepSuite.scala} (64%) delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 82f6c714f3555..4086970ffb256 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -167,5 +167,5 @@ private[spark] object Config extends Logging { val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." - val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala new file mode 100644 index 0000000000000..77b634ddfabcc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.internal.config.ConfigEntry + +private[spark] sealed trait KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark driver. + */ +private[spark] case class KubernetesDriverSpecificConf( + mainAppResource: Option[MainAppResource], + mainClass: String, + appName: String, + appArgs: Seq[String]) extends KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark executor. + */ +private[spark] case class KubernetesExecutorSpecificConf( + executorId: String, + driverPod: Pod) + extends KubernetesRoleSpecificConf + +/** + * Structure containing metadata for Kubernetes logic to build Spark pods. + */ +private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( + sparkConf: SparkConf, + roleSpecificConf: T, + appResourceNamePrefix: String, + appId: String, + roleLabels: Map[String, String], + roleAnnotations: Map[String, String], + roleSecretNamesToMountPaths: Map[String, String], + roleEnvs: Map[String, String]) { + + def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + + def sparkJars(): Seq[String] = sparkConf + .getOption("spark.jars") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def sparkFiles(): Seq[String] = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets(): Seq[LocalObjectReference] = { + sparkConf + .get(IMAGE_PULL_SECRETS) + .map(_.split(",")) + .getOrElse(Array.empty[String]) + .map(_.trim) + .map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + } + + def nodeSelector(): Map[String, String] = + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) + + def get(conf: String): String = sparkConf.get(conf) + + def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue) + + def getOption(key: String): Option[String] = sparkConf.getOption(key) +} + +private[spark] object KubernetesConf { + def createDriverConf( + sparkConf: SparkConf, + appName: String, + appResourceNamePrefix: String, + appId: String, + mainAppResource: Option[MainAppResource], + mainClass: String, + appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + val sparkConfWithMainAppJar = sparkConf.clone() + mainAppResource.foreach { + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + } + + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + val driverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + + KubernetesConf( + sparkConfWithMainAppJar, + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + appResourceNamePrefix, + appId, + driverLabels, + driverAnnotations, + driverSecretNamesToMountPaths, + driverEnvs) + } + + def createExecutorConf( + sparkConf: SparkConf, + executorId: String, + appId: String, + driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorCustomLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorCustomLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + val executorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorCustomLabels + val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnv = sparkConf.getExecutorEnv.toMap + + KubernetesConf( + sparkConf.clone(), + KubernetesExecutorSpecificConf(executorId, driverPod), + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appId, + executorLabels, + executorAnnotations, + executorSecrets, + executorEnv) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala similarity index 57% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala index cf41b22e241af..0c5ae022f4070 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -16,21 +16,16 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference +import io.fabric8.kubernetes.api.model.HasMetadata -import org.apache.spark.SparkFunSuite - -class KubernetesUtilsTest extends SparkFunSuite { - - test("testParseImagePullSecrets") { - val noSecrets = KubernetesUtils.parseImagePullSecrets(None) - assert(noSecrets === Nil) - - val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) - assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) - - val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) - assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) - } +private[spark] case class KubernetesDriverSpec( + pod: SparkPod, + driverKubernetesResources: Seq[HasMetadata], + systemProperties: Map[String, String]) +private[spark] object KubernetesDriverSpec { + def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( + SparkPod.initialPod(), + Seq.empty, + initialProps) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5b2bb819cdb14..ee629068ad90d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -37,17 +37,6 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } - /** - * Parses comma-separated list of imagePullSecrets into K8s-understandable format - */ - def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { - imagePullSecrets match { - case Some(secretsCommaSeparated) => - secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList - case None => Nil - } - } - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala deleted file mode 100644 index c35e7db51d407..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} - -/** - * Bootstraps a driver or executor container or an init-container with needed secrets mounted. - */ -private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { - - /** - * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. - * - * @param pod the pod into which the secret volumes are being added. - * @return the updated pod with the secret volumes added. - */ - def addSecretVolumes(pod: Pod): Pod = { - var podBuilder = new PodBuilder(pod) - secretNamesToMountPaths.keys.foreach { name => - podBuilder = podBuilder - .editOrNewSpec() - .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() - .endSpec() - } - - podBuilder.build() - } - - /** - * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the - * given container. - * - * @param container the container into which the secret volumes are being mounted. - * @return the updated container with the secrets mounted. - */ - def mountSecrets(container: Container): Container = { - var containerBuilder = new ContainerBuilder(container) - secretNamesToMountPaths.foreach { case (name, path) => - containerBuilder = containerBuilder - .addNewVolumeMount() - .withName(secretVolumeName(name)) - .withMountPath(path) - .endVolumeMount() - } - - containerBuilder.build() - } - - private def secretVolumeName(secretName: String): String = { - secretName + "-volume" - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala similarity index 64% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 17614e040e587..345dd117fd35f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -14,17 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -/** - * Represents a step in configuring the Spark driver pod. - */ -private[spark] trait DriverConfigurationStep { +private[spark] case class SparkPod(pod: Pod, container: Container) - /** - * Apply some transformation to the previous state of the driver to add a new feature to it. - */ - def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +private[spark] object SparkPod { + def initialPod(): SparkPod = { + SparkPod( + new PodBuilder() + .withNewMetadata() + .endMetadata() + .withNewSpec() + .endSpec() + .build(), + new ContainerBuilder().build()) + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala new file mode 100644 index 0000000000000..07bdccbe0479d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class BasicDriverFeatureStep( + conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"${conf.appResourceNamePrefix}-driver") + + private val driverContainerImage = conf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) + + // CPU settings + private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadMiB = conf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configurePod(pod: SparkPod): SparkPod = { + val driverCustomEnvs = conf.roleEnvs + .toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(pod.container) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverContainerImage) + .withImagePullPolicy(conf.imagePullPolicy()) + .addAllToEnv(driverCustomEnvs.asJava) + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryQuantity) + .endResources() + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", conf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(conf.roleSpecificConf.appArgs: _*) + .build() + + val driverPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(driverPodName) + .addToLabels(conf.roleLabels.asJava) + .addToAnnotations(conf.roleAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(conf.nodeSelector().asJava) + .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(driverPod, driverContainer) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val additionalProps = mutable.Map( + KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, + "spark.app.id" -> conf.appId, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") + + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkJars()) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkFiles()) + if (resolvedSparkJars.nonEmpty) { + additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + } + additionalProps.toMap + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala new file mode 100644 index 0000000000000..d22097587aafe --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) + extends KubernetesFeatureConfigStep { + + // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf + private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) + private val executorContainerImage = kubernetesConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val blockManagerPort = kubernetesConf + .sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + + private val driverUrl = RpcEndpointAddress( + kubernetesConf.get("spark.driver.host"), + kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) + private val executorMemoryString = kubernetesConf.get( + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = kubernetesConf + .get(EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = + if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } + private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def configurePod(pod: SparkPod): SparkPod = { + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCoresRequest) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = kubernetesConf + .get(EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ + kubernetesConf.roleEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder(pod.container) + .withName("executor") + .withImage(executorContainerImage) + .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .addToArgs("executor") + .build() + val containerWithLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + val driverPod = kubernetesConf.roleSpecificConf.driverPod + val executorPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(name) + .withLabels(kubernetesConf.roleLabels.asJava) + .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .editOrNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(executorPod, containerWithLimitCores) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala new file mode 100644 index 0000000000000..ff5ad6673b309 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) + extends KubernetesFeatureConfigStep { + // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. + // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. + + private val maybeMountedOAuthTokenFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + private val oauthTokenBase64 = kubernetesConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + + private val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + private val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + private val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder. + private val shouldMountSecret = oauthTokenBase64.isDefined || + caCertDataBase64.isDefined || + clientKeyDataBase64.isDefined || + clientCertDataBase64.isDefined + + private val driverCredentialsSecretName = + s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + + override def configurePod(pod: SparkPod): SparkPod = { + if (!shouldMountSecret) { + pod.copy( + pod = driverServiceAccount.map { account => + new PodBuilder(pod.pod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(pod.pod)) + } else { + val driverPodWithMountedKubernetesCredentials = + new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret() + .endVolume() + .endSpec() + .build() + + val driverContainerWithMountedSecretVolume = + new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val redactedTokens = kubernetesConf.sparkConf.getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) + .toMap + .mapValues( _ => "") + redactedTokens ++ + resolvedMountedCaCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientKeyFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedOAuthTokenFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (shouldMountSecret) { + Seq(createCredentialsSecret()) + } else { + Seq.empty + } + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + kubernetesConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. + */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + private def createCredentialsSecret(): Secret = { + val allSecretData = + resolveSecretData( + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + new SecretBuilder() + .withNewMetadata() + .withName(driverCredentialsSecretName) + .endMetadata() + .withData(allSecretData.asJava) + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala new file mode 100644 index 0000000000000..f2d7bbd08f305 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, SystemClock} + +private[spark] class DriverServiceFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + clock: Clock = new SystemClock) + extends KubernetesFeatureConfigStep with Logging { + import DriverServiceFeatureStep._ + + require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + private val driverPort = kubernetesConf.sparkConf.getInt( + "spark.driver.port", DEFAULT_DRIVER_PORT) + private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + + override def configurePod(pod: SparkPod): SparkPod = pod + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + Map(DRIVER_HOST_KEY -> driverHostname, + "spark.driver.port" -> driverPort.toString, + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> + driverBlockManagerPort.toString) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(kubernetesConf.roleLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + Seq(driverService) + } +} + +private[spark] object DriverServiceFeatureStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala new file mode 100644 index 0000000000000..4c1be3bb13293 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.HasMetadata + +import org.apache.spark.deploy.k8s.SparkPod + +/** + * A collection of functions that together represent a "feature" in pods that are launched for + * Spark drivers and executors. + */ +private[spark] trait KubernetesFeatureConfigStep { + + /** + * Apply modifications on the given pod in accordance to this feature. This can include attaching + * volumes, adding environment variables, and adding labels/annotations. + *

    + * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod + * object. So this is correct: + *

    +   * {@code val configuredPod = new PodBuilder(pod.pod)
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder(pod.container)
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + * This is incorrect: + *
    +   * {@code val configuredPod = new PodBuilder() // Loses the original state
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder() // Loses the original state
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + */ + def configurePod(pod: SparkPod): SparkPod + + /** + * Return any system properties that should be set on the JVM in accordance to this feature. + */ + def getAdditionalPodSystemProperties(): Map[String, String] + + /** + * Return any additional Kubernetes resources that should be added to support this feature. Only + * applicable when creating the driver in cluster mode. + */ + def getAdditionalKubernetesResources(): Seq[HasMetadata] +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala new file mode 100644 index 0000000000000..97fa9499b2edb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class MountSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedVolumes = kubernetesConf + .roleSecretNamesToMountPaths + .keys + .map(secretName => + new VolumeBuilder() + .withName(secretVolumeName(secretName)) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .build()) + val podWithVolumes = new PodBuilder(pod.pod) + .editOrNewSpec() + .addToVolumes(addedVolumes.toSeq: _*) + .endSpec() + .build() + val addedVolumeMounts = kubernetesConf + .roleSecretNamesToMountPaths + .map { + case (secretName, mountPath) => + new VolumeMountBuilder() + .withName(secretVolumeName(secretName)) + .withMountPath(mountPath) + .build() + } + val containerWithMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(addedVolumeMounts.toSeq: _*) + .build() + SparkPod(podWithVolumes, containerWithMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def secretVolumeName(secretName: String): String = s"$secretName-volume" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala deleted file mode 100644 index b4d3f04a1bc32..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.SystemClock -import org.apache.spark.util.Utils - -/** - * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to - * configure the Spark driver pod. The returned steps will be applied one by one in the given - * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. - */ -private[spark] class DriverConfigOrchestrator( - kubernetesAppId: String, - kubernetesResourceNamePrefix: String, - mainAppResource: Option[MainAppResource], - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) { - - // The resource name prefix is derived from the Spark application name, making it easy to connect - // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the - // application the user submitted. - - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - - def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - - val allDriverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> kubernetesAppId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - - val initialSubmissionStep = new BasicDriverConfigurationStep( - kubernetesAppId, - kubernetesResourceNamePrefix, - allDriverLabels, - imagePullPolicy, - appName, - mainClass, - appArgs, - sparkConf) - - val serviceBootstrapStep = new DriverServiceBootstrapStep( - kubernetesResourceNamePrefix, - allDriverLabels, - sparkConf, - new SystemClock) - - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - sparkConf, kubernetesResourceNamePrefix) - - val additionalMainAppJar = if (mainAppResource.nonEmpty) { - val mayBeResource = mainAppResource.get match { - case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => - Some(resource) - case _ => None - } - mayBeResource - } else { - None - } - - val sparkJars = sparkConf.getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty[String]) ++ - additionalMainAppJar.toSeq - val sparkFiles = sparkConf.getOption("spark.files") - .map(_.split(",")) - .getOrElse(Array.empty[String]) - - // TODO(SPARK-23153): remove once submission client local dependencies are supported. - if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { - throw new SparkException("The Kubernetes mode does not yet support referencing application " + - "dependencies in the local file system.") - } - - val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Seq(new DependencyResolutionStep( - sparkJars, - sparkFiles)) - } else { - Nil - } - - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq( - initialSubmissionStep, - serviceBootstrapStep, - kubernetesCredentialsStep) ++ - dependencyResolutionStep ++ - mountSecretsStep - } - - private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme == "file" - } - } - - private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme != "local" - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index e16d1add600b2..a97f5650fb869 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,12 +27,10 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -43,9 +41,9 @@ import org.apache.spark.util.Utils * @param driverArgs arguments to the driver */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], - mainClass: String, - driverArgs: Array[String]) + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) private[spark] object ClientArguments { @@ -80,8 +78,9 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * - * @param submissionSteps steps that collectively configure the driver - * @param sparkConf the submission client Spark configuration + * @param builder Responsible for building the base driver pod based on a composition of + * implemented features. + * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete @@ -89,31 +88,21 @@ private[spark] object ClientArguments { * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( - submissionSteps: Seq[DriverConfigurationStep], - sparkConf: SparkConf, + builder: KubernetesDriverBuilder, + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, watcher: LoggingPodStatusWatcher, kubernetesResourceNamePrefix: String) extends Logging { - /** - * Run command that initializes a DriverSpec that will be updated after each - * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec - * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources - */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) - // submissionSteps contain steps necessary to take, to resolve varying - // client arguments that are passed in, created by orchestrator - for (nextStep <- submissionSteps) { - currentDriverSpec = nextStep.configureDriver(currentDriverSpec) - } + val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" - val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap - val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) + val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container) .addNewEnv() .withName(ENV_SPARK_CONF_DIR) .withValue(SPARK_CONF_DIR_INTERNAL) @@ -123,7 +112,7 @@ private[spark] class Client( .withMountPath(SPARK_CONF_DIR_INTERNAL) .endVolumeMount() .build() - val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod) .editSpec() .addToContainers(resolvedDriverContainer) .addNewVolume() @@ -141,12 +130,10 @@ private[spark] class Client( .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { - if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = - currentDriverSpec.otherKubernetesResources ++ Seq(configMap) - addDriverOwnerReference(createdDriverPod, otherKubernetesResources) - kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() - } + val otherKubernetesResources = + resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap) + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } catch { case NonFatal(e) => kubernetesClient.pods().delete(createdDriverPod) @@ -180,20 +167,17 @@ private[spark] class Client( } // Build a Config Map that will house spark conf properties in a single file for spark-submit - private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = { val properties = new Properties() - conf.getAll.foreach { case (k, v) => + conf.foreach { case (k, v) => properties.setProperty(k, v) } val propertiesWriter = new StringWriter() properties.store(propertiesWriter, s"Java properties built from Kubernetes config map with name: $configMapName") - - val namespace = conf.get(KUBERNETES_NAMESPACE) new ConfigMapBuilder() .withNewMetadata() .withName(configMapName) - .withNamespace(namespace) .endMetadata() .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) .build() @@ -211,7 +195,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // For constructing the app ID, we can't use the Spark application name, as the app ID is going // to be added as a label to group resources belonging to the same application. Label values are // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate @@ -219,10 +203,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + val kubernetesConf = KubernetesConf.createDriverConf( + sparkConf, + appName, + kubernetesResourceNamePrefix, + kubernetesAppId, + clientArguments.mainAppResource, + clientArguments.mainClass, + clientArguments.driverArgs) + val builder = new KubernetesDriverBuilder + val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -230,15 +223,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val orchestrator = new DriverConfigOrchestrator( - kubernetesAppId, - kubernetesResourceNamePrefix, - clientArguments.mainAppResource, - appName, - clientArguments.mainClass, - clientArguments.driverArgs, - sparkConf) - Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, Some(namespace), @@ -247,8 +231,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - orchestrator.getAllConfigurationSteps, - sparkConf, + builder, + kubernetesConf, kubernetesClient, waitForAppCompletion, appName, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala new file mode 100644 index 0000000000000..c7579ed8cb689 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesDriverBuilder( + provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + new BasicDriverFeatureStep(_), + provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) + => DriverKubernetesCredentialsFeatureStep = + new DriverKubernetesCredentialsFeatureStep(_), + provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + new DriverServiceFeatureStep(_), + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountSecretsFeatureStep) = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideCredentialsStep(kubernetesConf), + provideServiceStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + for (feature <- allFeatures) { + val configuredPod = feature.configurePod(spec.pod) + val addedSystemProperties = feature.getAdditionalPodSystemProperties() + val addedResources = feature.getAdditionalKubernetesResources() + spec = KubernetesDriverSpec( + configuredPod, + spec.driverKubernetesResources ++ addedResources, + spec.systemProperties ++ addedSystemProperties) + } + spec + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala deleted file mode 100644 index db13f09387ef9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} - -import org.apache.spark.SparkConf - -/** - * Represents the components and characteristics of a Spark driver. The driver can be considered - * as being comprised of the driver pod itself, any other Kubernetes resources that the driver - * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver - * container should be operated on via the specific field of this case class as opposed to trying - * to edit the container directly on the pod. The driver container should be attached at the - * end of executing all submission steps. - */ -private[spark] case class KubernetesDriverSpec( - driverPod: Pod, - driverContainer: Container, - otherKubernetesResources: Seq[HasMetadata], - driverSparkConf: SparkConf) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { - KubernetesDriverSpec( - // Set new metadata and a new spec so that submission steps can use - // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - new ContainerBuilder().build(), - Seq.empty[HasMetadata], - initialSparkConf.clone()) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala deleted file mode 100644 index fcb1db8008053..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} -import org.apache.spark.launcher.SparkLauncher - -/** - * Performs basic configuration for the driver pod. - */ -private[spark] class BasicDriverConfigurationStep( - kubernetesAppId: String, - resourceNamePrefix: String, - driverLabels: Map[String, String], - imagePullPolicy: String, - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) extends DriverConfigurationStep { - - private val driverPodName = sparkConf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$resourceNamePrefix-driver") - - private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - - private val driverContainerImage = sparkConf - .get(DRIVER_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver container image")) - - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - - // CPU settings - private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) - - // Memory settings - private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val memoryOverheadMiB = sparkConf - .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) - private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(classPath) - .build() - } - - val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), - s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + - " Spark bookkeeping operations.") - - val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq - .map { env => - new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - } - - val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - - val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - - val driverCpuQuantity = new QuantityBuilder(false) - .withAmount(driverCpuCores) - .build() - val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryWithOverheadMiB}Mi") - .build() - val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => - ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) - } - - val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) - .withName(DRIVER_CONTAINER_NAME) - .withImage(driverContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(driverCustomEnvs.asJava) - .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_BIND_ADDRESS) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .endEnv() - .withNewResources() - .addToRequests("cpu", driverCpuQuantity) - .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryQuantity) - .addToLimits(maybeCpuLimitQuantity.toMap.asJava) - .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - - val driverContainer = appArgs.toList match { - case "" :: Nil | Nil => driverContainerWithoutArgs.build() - case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() - } - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - val baseDriverPod = new PodBuilder(driverSpec.driverPod) - .editOrNewMetadata() - .withName(driverPodName) - .addToLabels(driverLabels.asJava) - .addToAnnotations(driverAnnotations.asJava) - .endMetadata() - .withNewSpec() - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) - // to set the config variables to allow client-mode spark-submit from driver - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - - driverSpec.copy( - driverPod = baseDriverPod, - driverSparkConf = resolvedSparkConf, - driverContainer = driverContainer) - } - -} - diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala deleted file mode 100644 index 43de329f239ad..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import io.fabric8.kubernetes.api.model.ContainerBuilder - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Step that configures the classpath, spark.jars, and spark.files for the driver given that the - * user may provide remote files or files with local:// schemes. - */ -private[spark] class DependencyResolutionStep( - sparkJars: Seq[String], - sparkFiles: Seq[String]) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) - val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) - - val sparkConf = driverSpec.driverSparkConf.clone() - if (resolvedSparkJars.nonEmpty) { - sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) - } - val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { - new ContainerBuilder(driverSpec.driverContainer) - .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedSparkJars.mkString(File.pathSeparator)) - .endEnv() - .build() - } else { - driverSpec.driverContainer - } - - driverSpec.copy( - driverContainer = resolvedDriverContainer, - driverSparkConf = sparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala deleted file mode 100644 index 2424e63999a82..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials - * to request executors. - */ -private[spark] class DriverKubernetesCredentialsStep( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { - - private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") - private val maybeMountedClientKeyFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") - private val maybeMountedClientCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") - private val maybeMountedCaCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") - private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverSparkConf = driverSpec.driverSparkConf.clone() - - val oauthTokenBase64 = submissionSparkConf - .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") - .map { token => - BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) - } - val caCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - "Driver CA cert file") - val clientKeyDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - "Driver client key file") - val clientCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - "Driver client cert file") - - val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( - driverSparkConf, - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val kubernetesCredentialsSecret = createCredentialsSecret( - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .addNewVolume() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() - .endVolume() - .endSpec() - .build() - }.getOrElse( - driverServiceAccount.map { account => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .withServiceAccount(account) - .withServiceAccountName(account) - .endSpec() - .build() - }.getOrElse(driverSpec.driverPod) - ) - - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => - new ContainerBuilder(driverSpec.driverContainer) - .addNewVolumeMount() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) - .endVolumeMount() - .build() - }.getOrElse(driverSpec.driverContainer) - - driverSpec.copy( - driverPod = driverPodWithMountedKubernetesCredentials, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, - driverSparkConf = driverSparkConfWithCredentialsLocations, - driverContainer = driverContainerWithMountedSecretVolume) - } - - private def createCredentialsSecret( - driverOAuthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): Option[Secret] = { - val allSecretData = - resolveSecretData( - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ - resolveSecretData( - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ - resolveSecretData( - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ - resolveSecretData( - driverOAuthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) - - if (allSecretData.isEmpty) { - None - } else { - Some(new SecretBuilder() - .withNewMetadata() - .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") - .endMetadata() - .withData(allSecretData.asJava) - .build()) - } - } - - private def setDriverPodKubernetesCredentialLocations( - driverSparkConf: SparkConf, - driverOauthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): SparkConf = { - val resolvedMountedOAuthTokenFile = resolveSecretLocation( - maybeMountedOAuthTokenFile, - driverOauthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) - val resolvedMountedClientKeyFile = resolveSecretLocation( - maybeMountedClientKeyFile, - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_PATH) - val resolvedMountedClientCertFile = resolveSecretLocation( - maybeMountedClientCertFile, - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_PATH) - val resolvedMountedCaCertFile = resolveSecretLocation( - maybeMountedCaCertFile, - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_PATH) - - val sparkConfWithCredentialLocations = driverSparkConf - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - resolvedMountedCaCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - resolvedMountedClientKeyFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - resolvedMountedClientCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", - resolvedMountedOAuthTokenFile) - - // Redact all OAuth token values - sparkConfWithCredentialLocations - .getAll - .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) - .foreach { - sparkConfWithCredentialLocations.set(_, "") - } - sparkConfWithCredentialLocations - } - - private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { - submissionSparkConf.getOption(conf) - .map(new File(_)) - .map { file => - require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", - fileType, file.getAbsolutePath)) - BaseEncoding.base64().encode(Files.toByteArray(file)) - } - } - - private def resolveSecretLocation( - mountedUserSpecified: Option[String], - valueMountedFromSubmitter: Option[String], - mountedCanonicalLocation: String): Option[String] = { - mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => - mountedCanonicalLocation - }) - } - - /** - * Resolve a Kubernetes secret data entry from an optional client credential used by the - * driver to talk to the Kubernetes API server. - * - * @param userSpecifiedCredential the optional user-specified client credential. - * @param secretName name of the Kubernetes secret storing the client credential. - * @return a secret data entry in the form of a map from the secret name to the secret data, - * which may be empty if the user-specified credential is empty. - */ - private def resolveSecretData( - userSpecifiedCredential: Option[String], - secretName: String): Map[String, String] = { - userSpecifiedCredential.map { valueBase64 => - Map(secretName -> valueBase64) - }.getOrElse(Map.empty[String, String]) - } - - private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { - new OptionSettableSparkConf(sparkConf) - } -} - -private class OptionSettableSparkConf(sparkConf: SparkConf) { - def setOption(configEntry: String, option: Option[String]): SparkConf = { - option.foreach { opt => - sparkConf.set(configEntry, opt) - } - sparkConf - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala deleted file mode 100644 index 91e9a9f211335..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * A driver configuration step for mounting user-specified secrets onto user-specified paths. - * - * @param bootstrap a utility actually handling mounting of the secrets. - */ -private[spark] class DriverMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) - val container = bootstrap.mountSecrets(driverSpec.driverContainer) - driverSpec.copy( - driverPod = pod, - driverContainer = container - ) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala deleted file mode 100644 index 34af7cde6c1a9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.ServiceBuilder - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.Logging -import org.apache.spark.util.Clock - -/** - * Allows the driver to be reachable by executor pods through a headless service. The service's - * ports should correspond to the ports that the executor will reach the pod at for RPC. - */ -private[spark] class DriverServiceBootstrapStep( - resourceNamePrefix: String, - driverLabels: Map[String, String], - sparkConf: SparkConf, - clock: Clock) extends DriverConfigurationStep with Logging { - - import DriverServiceBootstrapStep._ - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, - s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + - "address is managed and set to the driver pod's IP address.") - require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, - s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + - "managed via a Kubernetes service.") - - val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" - val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { - preferredServiceName - } else { - val randomServiceId = clock.getTimeMillis() - val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" - logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + - s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + - s"$shorterServiceName as the driver service's name.") - shorterServiceName - } - - val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) - val driverService = new ServiceBuilder() - .withNewMetadata() - .withName(resolvedServiceName) - .endMetadata() - .withNewSpec() - .withClusterIP("None") - .withSelector(driverLabels.asJava) - .addNewPort() - .withName(DRIVER_PORT_NAME) - .withPort(driverPort) - .withNewTargetPort(driverPort) - .endPort() - .addNewPort() - .withName(BLOCK_MANAGER_PORT_NAME) - .withPort(driverBlockManagerPort) - .withNewTargetPort(driverBlockManagerPort) - .endPort() - .endSpec() - .build() - - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .set(DRIVER_HOST_KEY, driverHostname) - .set("spark.driver.port", driverPort.toString) - .set( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) - - driverSpec.copy( - driverSparkConf = resolvedSparkConf, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) - } -} - -private[spark] object DriverServiceBootstrapStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key - val DRIVER_SVC_POSTFIX = "-driver-svc" - val MAX_SERVICE_NAME_LENGTH = 63 -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala deleted file mode 100644 index 8607d6fba3234..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.util.Utils - -/** - * A factory class for bootstrapping and creating executor pods with the given bootstrapping - * components. - * - * @param sparkConf Spark configuration - * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto - * user-specified paths into the executor container - */ -private[spark] class ExecutorPodFactory( - sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap]) { - - private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - - private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - - private val executorAnnotations = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - private val nodeSelector = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_NODE_SELECTOR_PREFIX) - - private val executorContainerImage = sparkConf - .get(EXECUTOR_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor container image")) - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - private val blockManagerPort = sparkConf - .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - - private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - - private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) - private val executorMemoryString = sparkConf.get( - EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) - - private val memoryOverheadMiB = sparkConf - .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - - private val executorCores = sparkConf.getInt("spark.executor.cores", 1) - private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { - sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get - } else { - executorCores.toString - } - private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod = { - val name = s"$executorPodNamePrefix-exec-$executorId" - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - // hostname must be no longer than 63 characters, so take the last 63 characters of the pod - // name as the hostname. This preserves uniqueness since the end of name contains - // executorId - val hostname = name.substring(Math.max(0, name.length - 63)) - val resolvedExecutorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> applicationId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorLabels - val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") - .build() - val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCoresRequest) - .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = sparkConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() - } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - (ENV_EXECUTOR_CORES, executorCores.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, applicationId), - // This is to set the SPARK_CONF_DIR to be /opt/spark/conf - (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq - val requiredPorts = Seq( - (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) - .map { case (name, port) => - new ContainerPortBuilder() - .withName(name) - .withContainerPort(port) - .build() - } - - val executorContainer = new ContainerBuilder() - .withName("executor") - .withImage(executorContainerImage) - .withImagePullPolicy(imagePullPolicy) - .withNewResources() - .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryQuantity) - .addToRequests("cpu", executorCpuQuantity) - .endResources() - .addAllToEnv(executorEnv.asJava) - .withPorts(requiredPorts.asJava) - .addToArgs("executor") - .build() - - val executorPod = new PodBuilder() - .withNewMetadata() - .withName(name) - .withLabels(resolvedExecutorLabels.asJava) - .withAnnotations(executorAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withHostname(hostname) - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val containerWithLimitCores = executorLimitCores.map { limitCores => - val executorCpuLimitQuantity = new QuantityBuilder(false) - .withAmount(limitCores) - .build() - new ContainerBuilder(executorContainer) - .editResources() - .addToLimits("cpu", executorCpuLimitQuantity) - .endResources() - .build() - }.getOrElse(executorContainer) - - val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = - mountSecretsBootstrap.map { bootstrap => - (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) - }.getOrElse((executorPod, containerWithLimitCores)) - - - new PodBuilder(maybeSecretsMountedPod) - .editSpec() - .addToContainers(maybeSecretsMountedContainer) - .endSpec() - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ff5f6801da2a3..0ea80dfbc0d97 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -48,12 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit scheduler: TaskScheduler): SchedulerBackend = { val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -62,8 +56,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) - val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( @@ -71,7 +63,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - executorPodFactory, + new KubernetesExecutorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 9de4b16c30d3c..d86664c81071b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorPodFactory: ExecutorPodFactory, + executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, allocatorExecutor: ScheduledExecutorService, requestExecutorsService: ExecutorService) @@ -115,14 +116,19 @@ private[spark] class KubernetesClusterSchedulerBackend( for (_ <- 0 until math.min( currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorPod = executorPodFactory.createExecutorPod( + val executorConf = KubernetesConf.createExecutorConf( + conf, executorId, applicationId(), - driverUrl, - conf.getExecutorEnv, - driverPod, - currentNodeToLocalTaskCount) - executorsToAllocate(executorId) = executorPod + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + + executorsToAllocate(executorId) = podWithAttachedContainer logInfo( s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala new file mode 100644 index 0000000000000..22568fe7ea3be --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesExecutorBuilder( + provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + new BasicExecutorFeatureStep(_), + provideSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { + val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + var executorPod = SparkPod.initialPod() + for (feature <- allFeatures) { + executorPod = feature.configurePod(executorPod) + } + executorPod + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala new file mode 100644 index 0000000000000..f10202f7a3546 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class KubernetesConfSuite extends SparkFunSuite { + + private val APP_NAME = "test-app" + private val RESOURCE_NAME_PREFIX = "prefix" + private val APP_ID = "test-id" + private val MAIN_CLASS = "test-class" + private val APP_ARGS = Array("arg1", "arg2") + private val CUSTOM_LABELS = Map( + "customLabel1Key" -> "customLabel1Value", + "customLabel2Key" -> "customLabel2Value") + private val CUSTOM_ANNOTATIONS = Map( + "customAnnotation1Key" -> "customAnnotation1Value", + "customAnnotation2Key" -> "customAnnotation2Value") + private val SECRET_NAMES_TO_MOUNT_PATHS = Map( + "secret1" -> "/mnt/secrets/secret1", + "secret2" -> "/mnt/secrets/secret2") + private val CUSTOM_ENVS = Map( + "customEnvKey1" -> "customEnvValue1", + "customEnvKey2" -> "customEnvValue2") + private val DRIVER_POD = new PodBuilder().build() + private val EXECUTOR_ID = "executor-id" + + test("Basic driver translated fields.") { + val sparkConf = new SparkConf(false) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.appId === APP_ID) + assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) + assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) + assert(conf.roleSpecificConf.appName === APP_NAME) + assert(conf.roleSpecificConf.mainAppResource.isEmpty) + assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) + assert(conf.roleSpecificConf.appArgs === APP_ARGS) + } + + test("Creating driver conf with and without the main app jar influences spark.jars") { + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) + val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppJar, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") + .split(",") + === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) + val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + } + + test("Resolve driver labels, annotations, secret mount paths, and envs.") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) + } + CUSTOM_ENVS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) + } + + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.roleLabels === Map( + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ + CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleEnvs === CUSTOM_ENVS) + } + + test("Basic executor translated fields.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) + assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + } + + test("Image pull secrets.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false) + .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.imagePullSecrets() === + Seq( + new LocalObjectReferenceBuilder().withName("my-secret-1").build(), + new LocalObjectReferenceBuilder().withName("my-secret-2").build())) + } + + test("Set executor labels, annotations, and secrets") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) + } + + val conf = KubernetesConf.createExecutorConf( + sparkConf, + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleLabels === Map( + SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..eee85b8baa730 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class BasicDriverFeatureStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) + private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ENVS = Map( + DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, + DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + + test("Check the pod respects all configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + DRIVER_ENVS) + + val featureStep = new BasicDriverFeatureStep(kubernetesConf) + val basePod = SparkPod.initialPod() + val configuredPod = featureStep.configurePod(basePod) + + assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getImage === "spark-driver:latest") + assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + + assert(configuredPod.container.getEnv.size === 3) + val envs = configuredPod.container + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) + assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + + assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === + TEST_IMAGE_PULL_SECRET_OBJECTS) + + assert(configuredPod.container.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = configuredPod.container.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "456Mi") + val limits = resourceRequirements.getLimits.asScala + assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = configuredPod.pod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") + + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") + assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) + } + + test("Additional system properties resolve jars and set cluster-mode confs.") { + val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") + val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .setJars(allJars) + .set("spark.files", allFiles.mkString(",")) + .set(CONTAINER_IMAGE, "spark-driver:latest") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty) + val step = new BasicDriverFeatureStep(kubernetesConf) + val additionalProperties = step.getAdditionalPodSystemProperties() + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true", + "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + assert(additionalProperties === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala new file mode 100644 index 0000000000000..a764f7630b5c8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +class BasicExecutorFeatureStepSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + + private val APP_ID = "app-id" + private val DRIVER_HOSTNAME = "localhost" + private val DRIVER_PORT = 7098 + private val DRIVER_ADDRESS = RpcEndpointAddress( + DRIVER_HOSTNAME, + DRIVER_PORT.toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + private val DRIVER_POD_NAME = "driver-pod" + + private val DRIVER_POD_UID = "driver-uid" + private val RESOURCE_NAME_PREFIX = "base" + private val EXECUTOR_IMAGE = "executor-image" + private val LABELS = Map("label1key" -> "label1value") + private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") + private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + private val DRIVER_POD = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .withUid(DRIVER_POD_UID) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) + .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set("spark.driver.host", DRIVER_HOSTNAME) + .set("spark.driver.port", DRIVER_PORT.toString) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + } + + test("basic executor pod has reasonable defaults") { + val step = new BasicExecutorFeatureStep( + KubernetesConf( + baseConf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + val executor = step.configurePod(SparkPod.initialPod()) + + // The executor pod name and default labels. + assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") + assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.container.getImage === EXECUTOR_IMAGE) + assert(executor.container.getVolumeMounts.isEmpty) + assert(executor.container.getResources.getLimits.size() === 1) + assert(executor.container.getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. + assert(executor.pod.getSpec.getNodeSelector.isEmpty) + assert(executor.pod.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + longPodNamePrefix, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map("qux" -> "quux"))) + val executor = step.configurePod(SparkPod.initialPod()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + ENV_CLASSPATH -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> APP_ID, + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val mapEnvs = executorPod.container.getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala similarity index 67% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 64553d25883bb..9f817d3bfc79a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -14,34 +14,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features import java.io.File -import scala.collection.JavaConverters._ - import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.util.Utils -class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private val APP_ID = "k8s-app" private var credentialsTempDirectory: File = _ - private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) + private val BASE_DRIVER_POD = SparkPod.initialPod() + + @Mock + private var driverSpecificConf: KubernetesDriverSpecificConf = _ before { + MockitoAnnotations.initMocks(this) credentialsTempDirectory = Utils.createTempDir() } @@ -50,13 +51,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA } test("Don't set any credentials") { - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + val kubernetesConf = KubernetesConf( + new SparkConf(false), + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) } test("Only set credentials that are manually mounted.") { @@ -73,14 +80,23 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() + resolvedProperties.foreach { case (propKey, propValue) => + assert(submissionSparkConf.get(propKey) === propValue) + } } test("Mount credentials from the submission client as a secret.") { @@ -100,10 +116,17 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( - BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> @@ -113,16 +136,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> DRIVER_CREDENTIALS_CLIENT_CERT_PATH, s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - DRIVER_CREDENTIALS_CA_CERT_PATH, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> - clientKeyFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> - clientCertFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - caCertFile.getAbsolutePath) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - assert(preparedDriverSpec.otherKubernetesResources.size === 1) - val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + DRIVER_CREDENTIALS_CA_CERT_PATH) + assert(resolvedProperties === expectedSparkConf) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1) + val credentialsSecret = kubernetesCredentialsStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => @@ -134,12 +154,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") assert(decodedSecretData === expectedSecretData) - val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) + val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala assert(driverPodVolumes.size === 1) assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverPodVolumes.head.getSecret != null) assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) - val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala assert(driverContainerVolumeMount.size === 1) assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala new file mode 100644 index 0000000000000..c299d56865ec0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Clock + +class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + sparkConf = sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) + assert(configurationStep.getAdditionalKubernetesResources().size === 1) + assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val resolvedService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + resolvedService) + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) + === DEFAULT_BLOCKMANAGER_PORT.toString) + } + + test("Long prefixes should switch to using a generated name.") { + when(clock.getTimeMillis()).thenReturn(10000) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + } + + private def verifySparkConfHostNames( + driverSparkConf: Map[String, String], expectedHostName: String): Unit = { + assert(driverSparkConf( + org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala new file mode 100644 index 0000000000000..27bff74ce38af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark.deploy.k8s.SparkPod + +object KubernetesFeaturesTestUtils { + + def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep]( + stepType: String, stepClass: Class[T]): T = { + val mockStep = mock(stepClass) + when(mockStep.getAdditionalKubernetesResources()).thenReturn( + getSecretsForStepType(stepType)) + + when(mockStep.getAdditionalPodSystemProperties()) + .thenReturn(Map(stepType -> stepType)) + when(mockStep.configurePod(Matchers.any(classOf[SparkPod]))) + .thenAnswer(new Answer[SparkPod]() { + override def answer(invocation: InvocationOnMock): SparkPod = { + val originalPod = invocation.getArgumentAt(0, classOf[SparkPod]) + val configuredPod = new PodBuilder(originalPod.pod) + .editOrNewMetadata() + .addToLabels(stepType, stepType) + .endMetadata() + .build() + SparkPod(configuredPod, originalPod.container) + } + }) + mockStep + } + + def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String) + : Seq[HasMetadata] = { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(stepType) + .endMetadata() + .build()) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala similarity index 64% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 960d0bda1d011..9d02f56cc206d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -14,29 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} -class DriverMountSecretsStepSuite extends SparkFunSuite { +class MountSecretsFeatureStepSuite extends SparkFunSuite { private val SECRET_FOO = "foo" private val SECRET_BAR = "bar" private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("mounts all given secrets") { - val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val baseDriverPod = SparkPod.initialPod() val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + secretNamesToMountPaths, + Map.empty) - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) - val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) - val driverPodWithSecretsMounted = configuredDriverSpec.driverPod - val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + val step = new MountSecretsFeatureStep(kubernetesConf) + val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod + val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 6a501592f42a3..c1b203e03a357 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -16,22 +16,17 @@ */ package org.apache.spark.deploy.k8s.submit -import scala.collection.JavaConverters._ - -import com.google.common.collect.Iterables import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -39,6 +34,74 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" private val KUBERNETES_RESOURCE_PREFIX = "resource-example" + private val POD_NAME = "driver" + private val CONTAINER_NAME = "container" + private val APP_ID = "app-id" + private val APP_NAME = "app" + private val MAIN_CLASS = "main" + private val APP_ARGS = Seq("arg1", "arg2") + private val RESOLVED_JAVA_OPTIONS = Map( + "conf1key" -> "conf1value", + "conf2key" -> "conf2value") + private val BUILT_DRIVER_POD = + new PodBuilder() + .withNewMetadata() + .withName(POD_NAME) + .endMetadata() + .withNewSpec() + .withHostname("localhost") + .endSpec() + .build() + private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build() + private val ADDITIONAL_RESOURCES = Seq( + new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build()) + + private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec( + SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER), + ADDITIONAL_RESOURCES, + RESOLVED_JAVA_OPTIONS) + + private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() + .build() + private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD) + .editSpec() + .addToContainers(FULL_EXPECTED_CONTAINER) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap() + .endVolume() + .endSpec() + .build() + + private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + + private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret => + new SecretBuilder(secret) + .editMetadata() + .addNewOwnerReference() + .withName(POD_NAME) + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .withController(true) + .withUid(DRIVER_POD_UID) + .endOwnerReference() + .endMetadata() + .build() + } private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -56,113 +119,86 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + @Mock + private var driverBuilder: KubernetesDriverBuilder = _ + @Mock private var resourceList: ResourceList = _ - private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ + + private var sparkConf: SparkConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ - private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( + sparkConf, + KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KUBERNETES_RESOURCE_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.withName(POD_NAME)).thenReturn(namedPods) createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata]) - when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { - override def answer(invocation: InvocationOnMock): Pod = { - new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) - .editMetadata() - .withUid(DRIVER_POD_UID) - .endMetadata() - .withApiVersion(DRIVER_POD_API_VERSION) - .withKind(DRIVER_POD_KIND) - .build() - } - }) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE) when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) doReturn(resourceList) .when(kubernetesClient) .resourceList(createdResourcesArgumentCaptor.capture()) } - test("The client should configure the pod with the submission steps.") { + test("The client should configure the pod using the builder.") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) - assert(createdPod.getMetadata.getLabels.asScala === - Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) - assert(createdPod.getMetadata.getAnnotations.asScala === - Map(SecondTestConfigurationStep.annotationKey -> - SecondTestConfigurationStep.annotationValue)) - assert(createdPod.getSpec.getContainers.size() === 1) - assert(createdPod.getSpec.getContainers.get(0).getName === - SecondTestConfigurationStep.containerName) + verify(podOperations).create(FULL_EXPECTED_POD) } test("The client should create Kubernetes resources") { - val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" - val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( - submissionSteps, - new SparkConf(false) - .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues assert(otherCreatedResources.size === 2) - val secrets = otherCreatedResources.toArray - .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq + assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES) val configMaps = otherCreatedResources.toArray .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) assert(secrets.nonEmpty) - val secret = secrets.head - assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(secret.getData.asScala === - Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) - assert(ownerReference.getName === createdPod.getMetadata.getName) - assert(ownerReference.getKind === DRIVER_POD_KIND) - assert(ownerReference.getUid === DRIVER_POD_UID) - assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) assert(configMaps.nonEmpty) val configMap = configMaps.head assert(configMap.getMetadata.getName === s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( - "spark.custom-conf=custom-conf-value")) - val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) - assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverEnv = driverContainer.getEnv.asScala.head - assert(driverEnv.getName === ENV_SPARK_CONF_DIR) - assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) - val driverMount = driverContainer.getVolumeMounts.asScala.head - assert(driverMount.getName === SPARK_CONF_VOLUME) - assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value")) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value")) } test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, true, "spark", @@ -171,56 +207,4 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } - -} - -private object FirstTestConfigurationStep extends DriverConfigurationStep { - - val podName = "test-pod" - val secretName = "test-secret" - val labelKey = "first-submit" - val labelValue = "true" - val secretKey = "secretKey" - val secretData = "secretData" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .withName(podName) - .addToLabels(labelKey, labelValue) - .endMetadata() - .build() - val additionalResource = new SecretBuilder() - .withNewMetadata() - .withName(secretName) - .endMetadata() - .addToData(secretKey, secretData) - .build() - driverSpec.copy( - driverPod = modifiedPod, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) - } -} - -private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" - val annotationValue = "submitted" - val sparkConfKey = "spark.custom-conf" - val sparkConfValue = "custom-conf-value" - val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .addToAnnotations(annotationKey, annotationValue) - .endMetadata() - .build() - val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) - val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(containerName) - .build() - driverSpec.copy( - driverPod = modifiedPod, - driverSparkConf = resolvedSparkConf, - driverContainer = modifiedContainer) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala deleted file mode 100644 index df34d2dbcb5be..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.steps._ - -class DriverConfigOrchestratorSuite extends SparkFunSuite { - - private val DRIVER_IMAGE = "driver-image" - private val IC_IMAGE = "init-container-image" - private val APP_ID = "spark-app-id" - private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" - private val APP_NAME = "spark" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2") - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/driver" - - test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep]) - } - - test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Option.empty, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep]) - } - - test("Submission steps with driver secrets to mount") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverMountSecretsStep]) - } - - test("Submission using client local dependencies") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - var orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - - sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") - orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - } - - private def validateStepTypes( - orchestrator: DriverConfigOrchestrator, - types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps - assert(steps.size === types.size) - assert(steps.map(_.getClass) === types) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala new file mode 100644 index 0000000000000..161f9afe7bba9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesDriverBuilderSuite extends SparkFunSuite { + + private val BASIC_STEP_TYPE = "basic" + private val CREDENTIALS_STEP_TYPE = "credentials" + private val SERVICE_STEP_TYPE = "service" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) + + private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) + + private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest: KubernetesDriverBuilder = + new KubernetesDriverBuilder( + _ => basicFeatureStep, + _ => credentialsStep, + _ => serviceStep, + _ => secretsStep) + + test("Apply fundamental steps all the time.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) + : Unit = { + assert(resolvedSpec.systemProperties.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) + assert(resolvedSpec.driverKubernetesResources.containsSlice( + KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) + assert(resolvedSpec.systemProperties(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala deleted file mode 100644 index ee450fff8d376..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class BasicDriverConfigurationStepSuite extends SparkFunSuite { - - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" - - test("Set all possible configurations from the user.") { - val sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") - .set("spark.driver.cores", "2") - .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") - - val submissionStep = new BasicDriverConfigurationStep( - APP_ID, - RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - CONTAINER_IMAGE_PULL_POLICY, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = basePod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - - assert(preparedDriverSpec.driverContainer.getEnv.size === 4) - val envs = preparedDriverSpec.driverContainer - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") - assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") - - assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => - envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && - envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && - envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) - - val resourceRequirements = preparedDriverSpec.driverContainer.getResources - val requests = resourceRequirements.getRequests.asScala - assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "456Mi") - val limits = resourceRequirements.getLimits.asScala - assert(limits("memory").getAmount === "456Mi") - assert(limits("cpu").getAmount === "4") - - val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata - assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) - val expectedAnnotations = Map( - CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, - SPARK_APP_NAME_ANNOTATION -> APP_NAME) - assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - - val driverPodSpec = preparedDriverSpec.driverPod.getSpec - assert(driverPodSpec.getRestartPolicy === "Never") - assert(driverPodSpec.getImagePullSecrets.size() === 2) - assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap - val expectedSparkConf = Map( - KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, - "spark.kubernetes.submitInDriver" -> "true") - assert(resolvedSparkConf === expectedSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala deleted file mode 100644 index ca43fc97dc991..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class DependencyResolutionStepSuite extends SparkFunSuite { - - private val SPARK_JARS = Seq( - "apps/jars/jar1.jar", - "local:///var/apps/jars/jar2.jar") - - private val SPARK_FILES = Seq( - "apps/files/file1.txt", - "local:///var/apps/files/file2.txt") - - test("Added dependencies should be resolved in Spark configuration and environment") { - val dependencyResolutionStep = new DependencyResolutionStep( - SPARK_JARS, - SPARK_FILES) - val driverPod = new PodBuilder().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = driverPod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) - assert(preparedDriverSpec.driverPod === driverPod) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet - val expectedResolvedSparkJars = Set( - "apps/jars/jar1.jar", - "/var/apps/jars/jar2.jar") - assert(resolvedSparkJars === expectedResolvedSparkJars) - val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet - val expectedResolvedSparkFiles = Set( - "apps/files/file1.txt", - "/var/apps/files/file2.txt") - assert(resolvedSparkFiles === expectedResolvedSparkFiles) - val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(driverEnv.size === 1) - assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) - val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = expectedResolvedSparkJars - assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala deleted file mode 100644 index 78c8c3ba1afbd..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.util.Clock - -class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) - - private val LONG_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) - private val DRIVER_LABELS = Map( - "label1key" -> "label1value", - "label2key" -> "label2value") - - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - - test("Headless service has a port for the driver RPC and the block manager.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - assert(resolvedDriverSpec.otherKubernetesResources.size === 1) - assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - verifyService( - 9000, - 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - driverService) - } - - test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf, - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - verifyService( - DEFAULT_DRIVER_PORT, - DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) - assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === - DEFAULT_DRIVER_PORT.toString) - assert(resolvedDriverSpec.driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) - } - - test("Long prefixes should switch to using a generated name.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - when(clock.getTimeMillis()).thenReturn(10000) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" - assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Disallow bind address and driver host to be set explicitly.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - clock) - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") - } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") - } - } - - private def verifyService( - driverPort: Int, - blockManagerPort: Int, - expectedServiceName: String, - service: Service): Unit = { - assert(service.getMetadata.getName === expectedServiceName) - assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) - assert(service.getSpec.getPorts.size() === 2) - val driverServicePorts = service.getSpec.getPorts.asScala - assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) - assert(driverServicePorts.head.getPort.intValue() === driverPort) - assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) - assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) - assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) - assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) - } - - private def verifySparkConfHostNames( - driverSparkConf: SparkConf, expectedHostName: String): Unit = { - assert(driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala deleted file mode 100644 index d73df20f0f956..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.MockitoAnnotations -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { - - private val driverPodName: String = "driver-pod" - private val driverPodUid: String = "driver-uid" - private val executorPrefix: String = "base" - private val executorImage: String = "executor-image" - private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(driverPodName) - .withUid(driverPodUid) - .endMetadata() - .withNewSpec() - .withNodeName("some-node") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private var baseConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - baseConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(CONTAINER_IMAGE, executorImage) - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set(IMAGE_PULL_SECRETS, imagePullSecrets) - } - - test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - // The executor pod name and default labels. - assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") - assert(executor.getMetadata.getLabels.size() === 3) - assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - - // There is exactly 1 container with no volume mounts and default memory limits and requests. - // Default memory limit/request is 1024M + 384M (minimum overhead constant). - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getImage === executorImage) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) - assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources - .getRequests.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getContainers.get(0).getResources - .getLimits.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getImagePullSecrets.size() === 2) - assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - // The pod has no node selector, volumes. - assert(executor.getSpec.getNodeSelector.isEmpty) - assert(executor.getSpec.getVolumes.isEmpty) - - checkEnv(executor, Map()) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor core request specification") { - var factory = new ExecutorPodFactory(baseConf, None) - var executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "1") - - val conf = baseConf.clone() - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") - factory = new ExecutorPodFactory(conf, None) - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "0.1") - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - factory = new ExecutorPodFactory(conf, None) - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "100m") - } - - test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() - conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, - "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getHostname.length === 63) - } - - test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) - - checkEnv(executor, - Map("SPARK_JAVA_OPT_0" -> "foo=bar", - ENV_CLASSPATH -> "bar=baz", - "qux" -> "quux")) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor secrets get mounted") { - val conf = baseConf.clone() - - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - // check volume mounted. - assert(executor.getSpec.getVolumes.size() === 1) - assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - // There is always exactly one controller reference, and it points to the driver pod. - private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { - assert(executor.getMetadata.getOwnerReferences.size() === 1) - assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) - assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) - } - - // Check that the expected environment variables are present. - private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { - val defaultEnvs = Map( - ENV_EXECUTOR_ID -> "1", - ENV_DRIVER_URL -> "dummy", - ENV_EXECUTOR_CORES -> "1", - ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> "dummy", - ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) - val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { - x => (x.getName, x.getValue) - }.toMap - assert(defaultEnvs === mapEnvs) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index b2f26f205a329..96065e83f069c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} -import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.hamcrest.{BaseMatcher, Description, Matcher} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} import org.mockito.Mockito.{doNothing, never, times, verify, when} import org.scalatest.BeforeAndAfter @@ -31,6 +32,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc._ @@ -47,8 +49,6 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 private val POD_ALLOCATION_INTERVAL = "1m" - private val DRIVER_URL = RpcEndpointAddress( - SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() .withNewMetadata() .withName("pod1") @@ -94,7 +94,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var requestExecutorsService: ExecutorService = _ @Mock - private var executorPodFactory: ExecutorPodFactory = _ + private var executorBuilder: KubernetesExecutorBuilder = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -399,7 +399,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, - executorPodFactory, + executorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) { @@ -428,13 +428,22 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) .endMetadata() .build() - when(executorPodFactory.createExecutorPod( - executorId.toString, - APP_ID, - DRIVER_URL, - sparkConf.getExecutorEnv, - driverPod, - Map.empty)).thenReturn(resolvedPod) - resolvedPod + val resolvedContainer = new ContainerBuilder().build() + when(executorBuilder.buildFromFeatures(Matchers.argThat( + new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any) + : Boolean = { + argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && + argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + .roleSpecificConf.executorId == executorId.toString + } + + override def describeTo(description: Description): Unit = {} + }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) + new PodBuilder(resolvedPod) + .editSpec() + .addToContainers(resolvedContainer) + .endSpec() + .build() } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala new file mode 100644 index 0000000000000..f5270623f8acc --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesExecutorBuilderSuite extends SparkFunSuite { + private val BASIC_STEP_TYPE = "basic" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) + private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest = new KubernetesExecutorBuilder( + _ => basicFeatureStep, + _ => mountSecretsStep) + + test("Basic steps are consistently applied.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { + assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) + } + } +} From 4dfd746de3f4346ed0c2191f8523a7e6cc9f064d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 14 Apr 2018 00:22:38 +0800 Subject: [PATCH 308/593] [SPARK-23896][SQL] Improve PartitioningAwareFileIndex ## What changes were proposed in this pull request? Currently `PartitioningAwareFileIndex` accepts an optional parameter `userPartitionSchema`. If provided, it will combine the inferred partition schema with the parameter. However, 1. to get `userPartitionSchema`, we need to combine inferred partition schema with `userSpecifiedSchema` 2. to get the inferred partition schema, we have to create a temporary file index. Only after that, a final version of `PartitioningAwareFileIndex` can be created. This can be improved by passing `userSpecifiedSchema` to `PartitioningAwareFileIndex`. With the improvement, we can reduce redundant code and avoid parsing the file partition twice. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21004 from gengliangwang/PartitioningAwareFileIndex. --- .../datasources/CatalogFileIndex.scala | 2 +- .../execution/datasources/DataSource.scala | 133 ++++++++---------- .../datasources/InMemoryFileIndex.scala | 8 +- .../PartitioningAwareFileIndex.scala | 54 ++++--- .../streaming/MetadataLogFileIndex.scala | 10 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- 7 files changed, 103 insertions(+), 108 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4046396d0e614..a66a07673e25f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -85,7 +85,7 @@ class CatalogFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( - sparkSession, rootPaths, table.storage.properties, partitionSchema = None) + sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -103,24 +102,6 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. - * @param fileIndex A FileIndex that will perform partition inference - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { - val resolved = fileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } - /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -140,31 +121,26 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileStatusCache the shared cache for file statuses to speed up listing + * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to, e.g., + fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { + // The operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below - lazy val tempFileIndex = { - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) + lazy val tempFileIndex = fileIndex.getOrElse { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) } + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) + tempFileIndex.partitionSchema } else { // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning @@ -356,13 +332,7 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) - val fileCatalog = if (userSpecifiedSchema.nonEmpty) { - val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) - new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) - } else { - tempFileCatalog - } + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -384,24 +354,23 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap( - DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray - - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) - - val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && - catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && + catalogTable.get.partitionColumnNames.nonEmpty + val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes - new CatalogFileIndex( + val index = new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) + (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) + val index = createInMemoryFileIndex(globbedPaths) + val (resultDataSchema, resultPartitionSchema) = + getOrInferFileFormatSchema(format, Some(index)) + (index, resultDataSchema, resultPartitionSchema) } HadoopFsRelation( @@ -552,6 +521,40 @@ case class DataSource( sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } } + + /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */ + private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache) + } + + /** + * Checks and returns files in all the paths. + */ + private def checkAndGlobPathIfNecessary( + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() + allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toSeq + } } object DataSource extends Logging { @@ -699,30 +702,6 @@ object DataSource extends Logging { locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } - /** - * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. - * If `checkFilesExist` is `true`, also check the file existence. - */ - private def checkAndGlobPathIfNecessary( - hadoopConf: Configuration, - path: String, - checkFilesExist: Boolean): Seq[Path] = { - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - } - /** * Called before writing into a FileFormat based data source to make sure the * supplied schema is not empty. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 318ada0ceefc5..739d1f456e3ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration * @param rootPathsSpecified the list of root table paths to scan (some of which might be * filtered out later) * @param parameters as set of options to control discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, rootPathsSpecified: Seq[Path], parameters: Map[String, String], - partitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( - sparkSession, parameters, partitionSchema, fileStatusCache) { + sparkSession, parameters, userSpecifiedSchema, fileStatusCache) { // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 6b6f6388d54e8..cc8af7b92c454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType} * It provides the necessary methods to parse partition data based on a set of files. * * @param parameters as set of options to control partition discovery - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], - userPartitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { import PartitioningAwareFileIndex.BASE_PATH_PARAM @@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - - userPartitionSchema match { + val inferredPartitionSpec = PartitioningUtils.parsePartitions( + leafDirs, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths, + timeZoneId = timeZoneId) + userSpecifiedSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - typeInference = false, - basePaths = basePaths, - timeZoneId = timeZoneId) + val userPartitionSchema = + combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => + val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType, + Literal.create(row.get(i, dt), dt), + userPartitionSchema.fields(i).dataType, Option(timeZoneId)).eval() }: _*) } - PartitionSpec(userProvidedSchema, spec.partitions.map { part => + PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) case _ => - PartitioningUtils.parsePartitions( - leafDirs, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths, - timeZoneId = timeZoneId) + inferredPartitionSpec } } @@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } + + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. + * @param spec A partition inference result + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { + val equality = sparkSession.sessionState.conf.resolver + val resolved = spec.partitionColumns.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 1da703cefd8ea..5cacdd070b735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. * - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class MetadataLogFileIndex( sparkSession: SparkSession, path: Path, - userPartitionSchema: Option[StructType]) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { + userSpecifiedSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -51,7 +51,7 @@ class MetadataLogFileIndex( } override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - allFilesFromLog.toArray.groupBy(_.getPath.getParent) + allFilesFromLog.groupBy(_.getPath.getParent) } override def rootPaths: Seq[Path] = path :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..8764f0c42cf9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -401,7 +401,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi sparkSession = spark, rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], - partitionSchema = None) + userSpecifiedSchema = None) // This should not fail. fileCatalog.listLeafFiles(Seq(new Path(tempDir))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 1a86c604d5da3..3af163af0968c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -419,7 +419,7 @@ class PartitionedTablePerfStatsSuite HiveCatalogMetrics.reset() spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) - assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) } } } From 25892f3cc9dcb938220be8020a5b9a17c92dbdbe Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 14 Apr 2018 01:01:00 +0800 Subject: [PATCH 309/593] [SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer ## What changes were proposed in this pull request? Added a new rule to remove Sort operation when its child is already sorted. For instance, this simple code: ``` spark.sparkContext.parallelize(Seq(("a", "b"))).toDF("a", "b").registerTempTable("table1") val df = sql(s"""SELECT b | FROM ( | SELECT a, b | FROM table1 | ORDER BY a | ) t | ORDER BY a""".stripMargin) df.explain(true) ``` before the PR produces this plan: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(3) Project [b#7] +- *(3) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(2) Project [b#7, a#6] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` while after the PR produces: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(2) Project [b#7] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 5) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` this means that an unnecessary sort operation is not performed after the PR. ## How was this patch tested? added UT Author: Marco Gaido Closes #20560 from mgaido91/SPARK-23375. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../catalyst/plans/logical/LogicalPlan.scala | 9 ++ .../plans/logical/basicLogicalOperators.scala | 23 ++-- .../optimizer/RemoveRedundantSortsSuite.scala | 101 ++++++++++++++++++ .../spark/sql/execution/CacheManager.scala | 4 +- .../spark/sql/execution/ExistingRDD.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 17 +-- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +-- 10 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9a1bbc675e397..5fb59ef350b8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -733,6 +735,16 @@ object EliminateSorts extends Rule[LogicalPlan] { } } +/** + * Removes Sort operation if the child is already sorted + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + child + } +} + /** * Removes filters that can be evaluated trivially. This can be done through the following ways: * 1) by eliding the filter for cases where it will always evaluate to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c8ccd9bd03994..42034403d6d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -219,6 +219,11 @@ abstract class LogicalPlan * Refreshes (or invalidates) any metadata/data cached in the plan recursively. */ def refresh(): Unit = children.foreach(_.refresh()) + + /** + * Returns the output ordering that this plan generates. + */ + def outputOrdering: Seq[SortOrder] = Nil } /** @@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Seq(left, right) } + +abstract class OrderPreservingUnaryNode extends UnaryNode { + override final def outputOrdering: Seq[SortOrder] = child.outputOrdering +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..10df504795430 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { * This node is inserted at the top of a subquery when it is optimized. This makes sure we can * recognize a subquery as such, and it allows us to write subquery aware transformations. */ -case class Subquery(child: LogicalPlan) extends UnaryNode { +case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } -case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -125,7 +126,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) - extends UnaryNode with PredicateHelper { + extends OrderPreservingUnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -469,6 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def outputOrdering: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -522,6 +524,15 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( @@ -728,7 +739,7 @@ object Limit { * * See [[Limit]] for more information. */ -case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { @@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN * * See [[Limit]] for more information. */ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, child: LogicalPlan) - extends UnaryNode { + extends OrderPreservingUnaryNode { override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..2319ab8046e56 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} + +class RemoveRedundantSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.select('a).analyze + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d68aeb275afda..a8794be7280c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -99,7 +99,7 @@ class CacheManager extends Logging { sparkSession.sessionState.conf.columnBatchSize, storageLevel, sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, - planToCache.stats) + planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -148,7 +148,7 @@ class CacheManager extends Logging { storageLevel = cd.cachedRepresentation.storageLevel, child = spark.sessionState.executePlan(cd.plan).executedPlan, tableName = cd.cachedRepresentation.tableName, - statsOfPlanToCache = cd.plan.stats) + logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f3555508185fe..be50a1571a2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,7 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2579046e30708..a7ba9b86a176f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -39,9 +39,9 @@ object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String], - statsOfPlanToCache: Statistics): InMemoryRelation = + logicalPlan: LogicalPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = statsOfPlanToCache) + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) } @@ -64,7 +64,8 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -76,7 +77,8 @@ case class InMemoryRelation( tableName = None)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache) + statsOfPlanToCache, + outputOrdering) override def producedAttributes: AttributeSet = outputSet @@ -159,7 +161,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { @@ -172,7 +174,8 @@ case class InMemoryRelation( tableName)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache).asInstanceOf[this.type] + statsOfPlanToCache, + outputOrdering).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index cee85ec8af04d..949505e449fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id) + val data = spark.range(0, n, 1, 1).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f8b26f5b28cc7..40915a102bab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext { assert(planned.child.isInstanceOf[CollectLimitExec]) } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).reverse) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) + } + test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 26b63e8e8490f..9b7b316211d30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, - data.logicalPlan.stats) + data.logicalPlan) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) inMemoryRelation.cachedColumnBuffers.collect().head match { @@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) } @@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val logicalPlan = testData.select('value, 'key).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - logicalPlan.stats) + logicalPlan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -329,7 +330,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan) // Materialize the data. val expectedAnswer = data.collect() @@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, + LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) From 558f31b31c73b7e9f26f56498b54cf53997b59b8 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 13 Apr 2018 14:05:04 -0700 Subject: [PATCH 310/593] [SPARK-23963][SQL] Properly handle large number of columns in query on text-based Hive table ## What changes were proposed in this pull request? TableReader would get disproportionately slower as the number of columns in the query increased. I fixed the way TableReader was looking up metadata for each column in the row. Previously, it had been looking up this data in linked lists, accessing each linked list by an index (column number). Now it looks up this data in arrays, where indexing by column number works better. ## How was this patch tested? Manual testing All sbt unit tests python sql tests Author: Bruce Robbins Closes #21043 from bersprockets/tabreadfix. --- .../src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc93..b5444a4217924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -381,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal - }.unzip + }.toArray.unzip /** * Builds specific unwrappers ahead of time according to object inspector From cbb41a0c5b01579c85f06ef42cc0585fbef216c5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 13 Apr 2018 16:31:39 -0700 Subject: [PATCH 311/593] [SPARK-23966][SS] Refactoring all checkpoint file writing logic in a common CheckpointFileManager interface ## What changes were proposed in this pull request? Checkpoint files (offset log files, state store files) in Structured Streaming must be written atomically such that no partial files are generated (would break fault-tolerance guarantees). Currently, there are 3 locations which try to do this individually, and in some cases, incorrectly. 1. HDFSOffsetMetadataLog - This uses a FileManager interface to use any implementation of `FileSystem` or `FileContext` APIs. It preferably loads `FileContext` implementation as FileContext of HDFS has atomic renames. 1. HDFSBackedStateStore (aka in-memory state store) - Writing a version.delta file - This uses FileSystem APIs only to perform a rename. This is incorrect as rename is not atomic in HDFS FileSystem implementation. - Writing a snapshot file - Same as above. #### Current problems: 1. State Store behavior is incorrect - HDFS FileSystem implementation does not have atomic rename. 1. Inflexible - Some file systems provide mechanisms other than write-to-temp-file-and-rename for writing atomically and more efficiently. For example, with S3 you can write directly to the final file and it will be made visible only when the entire file is written and closed correctly. Any failure can be made to terminate the writing without making any partial files visible in S3. The current code does not abstract out this mechanism enough that it can be customized. #### Solution: 1. Introduce a common interface that all 3 cases above can use to write checkpoint files atomically. 2. This interface must provide the necessary interfaces that allow customization of the write-and-rename mechanism. This PR does that by introducing the interface `CheckpointFileManager` and modifying `HDFSMetadataLog` and `HDFSBackedStateStore` to use the interface. Similar to earlier `FileManager`, there are implementations based on `FileSystem` and `FileContext` APIs, and the latter implementation is preferred to make it work correctly with HDFS. The key method this interface has is `createAtomic(path, overwrite)` which returns a `CancellableFSDataOutputStream` that has the method `cancel()`. All users of this method need to either call `close()` to successfully write the file, or `cancel()` in case of an error. ## How was this patch tested? New tests in `CheckpointFileManagerSuite` and slightly modified existing tests. Author: Tathagata Das Closes #21048 from tdas/SPARK-23966. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../streaming/CheckpointFileManager.scala | 349 ++++++++++++++++++ .../execution/streaming/HDFSMetadataLog.scala | 229 +----------- .../state/HDFSBackedStateStoreProvider.scala | 120 +++--- .../streaming/state/StateStore.scala | 4 +- .../CheckpointFileManagerSuite.scala | 192 ++++++++++ .../CompactibleFileStreamLogSuite.scala | 5 - .../streaming/HDFSMetadataLogSuite.scala | 116 +----- .../streaming/state/StateStoreSuite.scala | 58 ++- 9 files changed, 678 insertions(+), 402 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1c8ab9c62623e..0dc47bfe075d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -930,6 +930,13 @@ object SQLConf { .intConf .createWithDefault(100) + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .internal() + .stringConf + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala new file mode 100644 index 0000000000000..606ba250ad9d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException, OutputStream} +import java.util.{EnumSet, UUID} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs} +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * An interface to abstract out all operation related to streaming checkpoints. Most importantly, + * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a + * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and + * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations + * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or + * without overwrite. + * + * This higher-level interface above the Hadoop FileSystem is necessary because + * different implementation of FileSystem/FileContext may have different combination of operations + * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename, + * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while + * keeping the usage simple (`createAtomic` -> `close` or `cancel`). + */ +trait CheckpointFileManager { + + import org.apache.spark.sql.execution.streaming.CheckpointFileManager._ + + /** + * Create a file and make its contents available atomically after the output stream is closed. + * + * @param path Path to create + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** List the files in a path that match a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** List all the files in a path. */ + def list(path: Path): Array[FileStatus] = { + list(path, new PathFilter { override def accept(path: Path): Boolean = true }) + } + + /** Make directory at the give path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists */ + def exists(path: Path): Boolean + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + + /** Is the default file system this implementation is operating on the local file system. */ + def isLocal: Boolean +} + +object CheckpointFileManager extends Logging { + + /** + * Additional methods in CheckpointFileManager implementations that allows + * [[RenameBasedFSDataOutputStream]] get atomicity by write-to-temp-file-and-rename + */ + sealed trait RenameHelperMethods { self => CheckpointFileManager + /** Create a file with overwrite. */ + def createTempFile(path: Path): FSDataOutputStream + + /** + * Rename a file. + * + * @param srcPath Source path to rename + * @param dstPath Destination path to rename to + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit + } + + /** + * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used + * mainly by `CheckpointFileManager.createAtomic` to write a file atomically. + * + * @see [[CheckpointFileManager]]. + */ + abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream) + extends FSDataOutputStream(underlyingStream, null) { + /** Cancel the `underlyingStream` and ensure that the output file is not generated. */ + def cancel(): Unit + } + + /** + * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing + * to a temporary file and then renames. + */ + sealed class RenameBasedFSDataOutputStream( + fm: CheckpointFileManager with RenameHelperMethods, + finalPath: Path, + tempPath: Path, + overwriteIfPossible: Boolean) + extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) { + + def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = { + this(fm, path, generateTempPath(path), overwrite) + } + + logInfo(s"Writing atomically to $finalPath using temp file $tempPath") + @volatile private var terminated = false + + override def close(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + try { + fm.renameTempFile(tempPath, finalPath, overwriteIfPossible) + } catch { + case fe: FileAlreadyExistsException => + logWarning( + s"Failed to rename temp file $tempPath to $finalPath because file exists", fe) + if (!overwriteIfPossible) throw fe + } + logInfo(s"Renamed temp file $tempPath to $finalPath") + } finally { + terminated = true + } + } + + override def cancel(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + fm.delete(tempPath) + } catch { + case NonFatal(e) => + logWarning(s"Error cancelling write to $finalPath", e) + } finally { + terminated = true + } + } + } + + + /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */ + def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = { + val fileManagerClass = hadoopConf.get( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key) + if (fileManagerClass != null) { + return Utils.classForName(fileManagerClass) + .getConstructor(classOf[Path], classOf[Configuration]) + .newInstance(path, hadoopConf) + .asInstanceOf[CheckpointFileManager] + } + try { + // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename() + // gives atomic renames, which is what we rely on for the default implementation + // `CheckpointFileManager.createAtomic`. + new FileContextBasedCheckpointFileManager(path, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning( + "Could not use FileContext API for managing Structured Streaming checkpoint files at " + + s"$path. Using FileSystem API instead for managing log files. If the implementation " + + s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of" + + s"your Structured Streaming is not guaranteed.") + new FileSystemBasedCheckpointFileManager(path, hadoopConf) + } + } + + private def generateTempPath(path: Path): Path = { + val tc = org.apache.spark.TaskContext.get + val tid = if (tc != null) ".TID" + tc.taskAttemptId else "" + new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp") + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. */ +class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + protected val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + fs.create(path, true) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def exists(path: Path): Boolean = { + try + return fs.getFileStatus(path) != null + catch { + case e: FileNotFoundException => + return false + } + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + if (!overwriteIfPossible && fs.exists(dstPath)) { + throw new FileAlreadyExistsException( + s"Failed to rename $srcPath to $dstPath as destination already exists") + } + + if (!fs.rename(srcPath, dstPath)) { + // FileSystem.rename() returning false is very ambiguous as it can be for many reasons. + // This tries to make a best effort attempt to return the most appropriate exception. + if (fs.exists(dstPath)) { + if (!overwriteIfPossible) { + throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists") + } + } else if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Failed to rename as $srcPath was not found") + } else { + val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false" + logWarning(msg) + throw new IOException(msg) + } + } + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + logInfo(s"Failed to delete $path as it does not exist") + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => true + case _ => false + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. */ +class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + import CreateFlag._ + import Options._ + fc.create( + path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled())) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def exists(path: Path): Boolean = { + fc.util.exists(path) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + import Options.Rename._ + fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE) + } + + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fc.getDefaultFileSystem match { + case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs + case _ => false + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 00bc215a5dc8c..bd0a46115ceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") - import HDFSMetadataLog._ - val metadataPath = new Path(path) - protected val fileManager = createFileManager() + + protected val fileManager = + CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - writeBatch(batchId, metadata) + writeBatchToFile(metadata, batchIdToPath(batchId)) true } } - private def writeTempBatch(metadata: T): Option[Path] = { - while (true) { - val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") - try { - val output = fileManager.create(tempPath) - try { - serialize(metadata, output) - return Some(tempPath) - } finally { - output.close() - } - } catch { - case e: FileAlreadyExistsException => - // Failed to create "tempPath". There are two cases: - // 1. Someone is creating "tempPath" too. - // 2. This is a restart. "tempPath" has already been created but not moved to the final - // batch file (not committed). - // - // For both cases, the batch has not yet been committed. So we can retry it. - // - // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use - // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a - // big problem because it requires the attacker must have the permission to write the - // metadata path. In addition, the old Streaming also have this issue, people can create - // malicious checkpoint files to crash a Streaming application too. - } - } - None - } - - /** - * Write a batch to a temp file then rename it to the batch file. + /** Write a batch to a temp file then rename it to the batch file. * * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, metadata: T): Unit = { - val tempPath = writeTempBatch(metadata).getOrElse( - throw new IllegalStateException(s"Unable to create temp batch file $batchId")) + private def writeBatchToFile(metadata: T, path: Path): Unit = { + val output = fileManager.createAtomic(path, overwriteIfPossible = false) try { - // Try to commit the batch - // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") - fileManager.rename(tempPath, batchIdToPath(batchId)) - - // SPARK-17475: HDFSMetadataLog should not leak CRC files - // If the underlying filesystem didn't rename the CRC file, delete it. - val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") - if (fileManager.exists(crcPath)) fileManager.delete(crcPath) + serialize(metadata, output) + output.close() } catch { case e: FileAlreadyExistsException => - // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. - // So throw an exception to tell the user this is not a valid behavior. + output.cancel() + // If next batch file already exists, then another concurrently running query has + // written it. throw new ConcurrentModificationException( - s"Multiple HDFSMetadataLog are using $path", e) - } finally { - fileManager.delete(tempPath) - } - } - - /** - * @return the deserialized metadata in a batch file, or None if file not exist. - * @throws IllegalArgumentException when path does not point to a batch file. - */ - def get(batchFile: Path): Option[T] = { - if (fileManager.exists(batchFile)) { - if (isBatchFile(batchFile)) { - get(pathToBatchId(batchFile)) - } else { - throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!") - } - } else { - None + s"Multiple streaming queries are concurrently using $path", e) + case e: Throwable => + output.cancel() + throw e } } @@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => @@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - private def createFileManager(): FileManager = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - try { - new FileContextManager(metadataPath, hadoopConf) - } catch { - case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log files at path " + - s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " + - s"inconsistent under failures.") - new FileSystemManager(metadataPath, hadoopConf) - } - } - /** * Parse the log version from the given `text` -- will throw exception when the parsed version * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", @@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: object HDFSMetadataLog { - /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ - trait FileManager { - - /** List the files in a path that match a filter. */ - def list(path: Path, filter: PathFilter): Array[FileStatus] - - /** Make directory at the give path and all its parent directories as needed. */ - def mkdirs(path: Path): Unit - - /** Whether path exists */ - def exists(path: Path): Boolean - - /** Open a file for reading, or throw exception if it does not exist. */ - def open(path: Path): FSDataInputStream - - /** Create path, or throw exception if it already exists */ - def create(path: Path): FSDataOutputStream - - /** - * Atomically rename path, or throw exception if it cannot be done. - * Should throw FileNotFoundException if srcPath does not exist. - * Should throw FileAlreadyExistsException if destPath already exists. - */ - def rename(srcPath: Path, destPath: Path): Unit - - /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ - def delete(path: Path): Unit - } - - /** - * Default implementation of FileManager using newer FileContext API. - */ - class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fc = if (path.toUri.getScheme == null) { - FileContext.getFileContext(hadoopConf) - } else { - FileContext.getFileContext(path.toUri, hadoopConf) - } - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fc.util.listStatus(path, filter) - } - - override def rename(srcPath: Path, destPath: Path): Unit = { - fc.rename(srcPath, destPath) - } - - override def mkdirs(path: Path): Unit = { - fc.mkdir(path, FsPermission.getDirDefault, true) - } - - override def open(path: Path): FSDataInputStream = { - fc.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fc.create(path, EnumSet.of(CreateFlag.CREATE)) - } - - override def exists(path: Path): Boolean = { - fc.util().exists(path) - } - - override def delete(path: Path): Unit = { - try { - fc.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - - /** - * Implementation of FileManager using older FileSystem API. Note that this implementation - * cannot provide atomic renaming of paths, hence can lead to consistency issues. This - * should be used only as a backup option, when FileContextManager cannot be used. - */ - class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fs = path.getFileSystem(hadoopConf) - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fs.listStatus(path, filter) - } - - /** - * Rename a path. Note that this implementation is not atomic. - * @throws FileNotFoundException if source path does not exist. - * @throws FileAlreadyExistsException if destination path already exists. - * @throws IOException if renaming fails for some unknown reason. - */ - override def rename(srcPath: Path, destPath: Path): Unit = { - if (!fs.exists(srcPath)) { - throw new FileNotFoundException(s"Source path does not exist: $srcPath") - } - if (fs.exists(destPath)) { - throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") - } - if (!fs.rename(srcPath, destPath)) { - throw new IOException(s"Failed to rename $srcPath to $destPath") - } - } - - override def mkdirs(path: Path): Unit = { - fs.mkdirs(path, FsPermission.getDirDefault) - } - - override def open(path: Path): FSDataInputStream = { - fs.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fs.create(path, false) - } - - override def exists(path: Path): Boolean = { - fs.exists(path) - } - - override def delete(path: Path): Unit = { - try { - fs.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - /** * Verify if batchIds are continuous and between `startId` and `endId`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3f5002a4e6937..df722b953228b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.io._ import java.nio.channels.ClosedChannelException import java.util.Locale @@ -27,13 +27,16 @@ import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} @@ -87,10 +90,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case object ABORTED extends STATE private val newVersion = version + 1 - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) @volatile private var state: STATE = UPDATING - @volatile private var finalDeltaFile: Path = null + private val finalDeltaFile: Path = deltaFile(newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream) override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId @@ -103,14 +106,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val keyCopy = key.copy() val valueCopy = value.copy() mapToUpdate.put(keyCopy, valueCopy) - writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy) } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") val prevValue = mapToUpdate.remove(key) if (prevValue != null) { - writeRemoveToDeltaFile(tempDeltaFileStream, key) + writeRemoveToDeltaFile(compressedStream, key) } } @@ -126,8 +129,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit verify(state == UPDATING, "Cannot commit after already committed or aborted") try { - finalizeDeltaFile(tempDeltaFileStream) - finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + commitUpdates(newVersion, mapToUpdate, compressedStream) state = COMMITTED logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion @@ -140,23 +142,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - try { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + if (state == UPDATING) { + state = ABORTED + cancelDeltaFile(compressedStream, deltaFileStream) + } else { state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) - } - } catch { - case c: ClosedChannelException => - // This can happen when underlying file output stream has been closed before the - // compression stream. - logDebug(s"Error aborting version $newVersion into $this", c) - - case e: Exception => - logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -212,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf - fs.mkdirs(baseDir) + fm.mkdirs(baseDir) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -251,31 +244,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val loadedMaps = new mutable.HashMap[Long, MapType] private lazy val baseDir = stateStoreId.storeCheckpointLocation() - private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - /** Commit a set of updates to the store with the given new version */ - private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { - val finalDeltaFile = deltaFile(newVersion) - - // scalastyle:off - // Renaming a file atop an existing one fails on HDFS - // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). - // Hence we should either skip the rename step or delete the target file. Because deleting the - // target file will break speculation, skipping the rename step is the only choice. It's still - // semantically correct because Structured Streaming requires rerunning a batch should - // generate the same output. (SPARK-19677) - // scalastyle:on - if (fs.exists(finalDeltaFile)) { - fs.delete(tempDeltaFile, true) - } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { - throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") - } + finalizeDeltaFile(output) loadedMaps.put(newVersion, map) - finalDeltaFile } } @@ -365,7 +342,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val fileToRead = deltaFile(version) var input: DataInputStream = null val sourceStream = try { - fs.open(fileToRead) + fm.open(fileToRead) } catch { case f: FileNotFoundException => throw new IllegalStateException( @@ -412,12 +389,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val tempFile = - new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") + val targetFile = snapshotFile(version) + var rawOutput: CancellableFSDataOutputStream = null var output: DataOutputStream = null - Utils.tryWithSafeFinally { - output = compressStream(fs.create(tempFile, false)) + try { + rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true) + output = compressStream(rawOutput) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -429,16 +406,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit output.write(valueBytes) } output.writeInt(-1) - } { - if (output != null) output.close() + output.close() + } catch { + case e: Throwable => + cancelDeltaFile(compressedStream = output, rawStream = rawOutput) + throw e } - if (fs.exists(fileToWrite)) { - // Skip rename if the file is alreayd created. - fs.delete(tempFile, true) - } else if (!fs.rename(tempFile, fileToWrite)) { - throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + logInfo(s"Written snapshot file for version $version of $this at $targetFile") + } + + /** + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + private def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush flush data into the + // rawStream. Since the rawStream is already closed, there may be errors. + // Usually its an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError. } - logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -447,7 +442,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit var input: DataInputStream = null try { - input = decompressStream(fs.open(fileToRead)) + input = decompressStream(fm.open(fileToRead)) var eof = false while (!eof) { @@ -508,7 +503,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case None => // The last map is not loaded, probably some other instance is in charge } - } } catch { case NonFatal(e) => @@ -534,7 +528,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) filesToDelete.foreach { f => - fs.delete(f.path, true) + fm.delete(f.path) } logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) @@ -576,7 +570,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) + fm.list(baseDir) } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d1d9f95cb0977..7eb68c21569ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -459,7 +459,6 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - logInfo("Env is not null") val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER || env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -467,13 +466,12 @@ object StateStore extends Logging { // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. if (isDriver || _coordRef == null) { - logInfo("Getting StateStoreCoordinatorRef") + logDebug("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { - logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala new file mode 100644 index 0000000000000..fe59cb25d5005 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +abstract class CheckpointFileManagerTests extends SparkFunSuite { + + def createManager(path: Path): CheckpointFileManager + + test("mkdirs, list, createAtomic, open, delete, exists") { + withTempPath { p => + val basePath = new Path(p.getAbsolutePath) + val fm = createManager(basePath) + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create atomic without overwrite + var path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).close() + assert(fm.exists(path)) + quietly { + intercept[IOException] { + // should throw exception since file exists and overwrite is false + fm.createAtomic(path, overwriteIfPossible = false).close() + } + } + + // Create atomic with overwrite if possible + path = new Path(s"$dir/file2") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() + assert(fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() // should not throw exception + + // Open and delete + fm.open(path).close() + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + } + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} + +class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("CheckpointFileManager.create() should pick up user-specified class from conf") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key -> + classOf[CreateAtomicTestManager].getName) { + val fileManager = + CheckpointFileManager.create(new Path("/"), spark.sessionState.newHadoopConf) + assert(fileManager.isInstanceOf[CreateAtomicTestManager]) + } + } + + test("CheckpointFileManager.create() should fallback from FileContext to FileSystem") { + import CheckpointFileManagerSuiteFileSystem.scheme + spark.conf.set(s"fs.$scheme.impl", classOf[CheckpointFileManagerSuiteFileSystem].getName) + quietly { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) + + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) + } + } + } +} + +class FileContextBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileContextBasedCheckpointFileManager(path, new Configuration()) + } +} + +class FileSystemBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileSystemBasedCheckpointFileManager(path, new Configuration()) + } +} + + +/** A fake implementation to test different characteristics of CheckpointFileManager interface */ +class CreateAtomicTestManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + import CheckpointFileManager._ + + override def createAtomic(path: Path, overwrite: Boolean): CancellableFSDataOutputStream = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + } + val originalOut = super.createAtomic(path, overwrite) + + new CancellableFSDataOutputStream(originalOut) { + override def close(): Unit = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + throw new IOException("Copy failed intentionally") + } + super.close() + } + + override def cancel(): Unit = { + CreateAtomicTestManager.cancelCalledInCreateAtomic = true + originalOut.cancel() + } + } + } +} + +object CreateAtomicTestManager { + @volatile var shouldFailInCreateAtomic = false + @volatile var cancelCalledInCreateAtomic = false +} + + +/** + * CheckpointFileManagerSuiteFileSystem to test fallback of the CheckpointFileManager + * from FileContext to FileSystem API. + */ +private class CheckpointFileManagerSuiteFileSystem extends RawLocalFileSystem { + import CheckpointFileManagerSuiteFileSystem.scheme + + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +private object CheckpointFileManagerSuiteFileSystem { + val scheme = s"CheckpointFileManagerSuiteFileSystem${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 12eaf63415081..ec961a9ecb592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -22,15 +22,10 @@ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - import CompactibleFileStreamLog._ /** -- testing of `object CompactibleFileStreamLog` begins -- */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4677769c12a35..9268306ce4275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -17,46 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.io.{File, FileNotFoundException, IOException} -import java.net.URI +import java.io.File import java.util.ConcurrentModificationException import scala.language.implicitConversions -import scala.util.Random -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ -import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - private implicit def toOption[A](a: A): Option[A] = Option(a) - test("FileManager: FileContextManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileContextManager(path, new Configuration)) - } - } - - test("FileManager: FileSystemManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileSystemManager(path, new Configuration)) - } - } - test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir @@ -82,26 +58,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - spark.conf.set( - s"fs.$scheme.impl", - classOf[FakeFileSystem].getName) - withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog.add(0, "batch0")) - assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - - - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog2.get(0) === Some("batch0")) - assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) - - } - } - test("HDFSMetadataLog: purge") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -121,7 +77,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { // There should be exactly one file, called "2", in the metadata directory. // This check also tests for regressions of SPARK-17475 - val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + val allFiles = new File(metadataLog.metadataPath.toString).listFiles() + .filter(!_.getName.startsWith(".")).toSeq assert(allFiles.size == 1) assert(allFiles(0).getName() == "2") } @@ -172,7 +129,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("HDFSMetadataLog: metadata directory collision") { + testQuietly("HDFSMetadataLog: metadata directory collision") { withTempDir { temp => val waiter = new Waiter val maxBatchId = 100 @@ -206,60 +163,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - /** Basic test case for [[FileManager]] implementation. */ - private def testFileManager(basePath: Path, fm: FileManager): Unit = { - // Mkdirs - val dir = new Path(s"$basePath/dir/subdir/subsubdir") - assert(!fm.exists(dir)) - fm.mkdirs(dir) - assert(fm.exists(dir)) - fm.mkdirs(dir) - - // List - val acceptAllFilter = new PathFilter { - override def accept(path: Path): Boolean = true - } - val rejectAllFilter = new PathFilter { - override def accept(path: Path): Boolean = false - } - assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) - assert(fm.list(basePath, rejectAllFilter).length === 0) - - // Create - val path = new Path(s"$dir/file") - assert(!fm.exists(path)) - fm.create(path).close() - assert(fm.exists(path)) - intercept[IOException] { - fm.create(path) - } - - // Open and delete - fm.open(path).close() - fm.delete(path) - assert(!fm.exists(path)) - intercept[IOException] { - fm.open(path) - } - fm.delete(path) // should not throw exception - - // Rename - val path1 = new Path(s"$dir/file1") - val path2 = new Path(s"$dir/file2") - fm.create(path1).close() - assert(fm.exists(path1)) - fm.rename(path1, path2) - intercept[FileNotFoundException] { - fm.rename(path1, path2) - } - val path3 = new Path(s"$dir/file3") - fm.create(path3).close() - assert(fm.exists(path3)) - intercept[FileAlreadyExistsException] { - fm.rename(path2, path3) - } - } - test("verifyBatchIds") { import HDFSMetadataLog.verifyBatchIds verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) @@ -277,14 +180,3 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) } } - -/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ -class FakeFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } -} - -object FakeFileSystem { - val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c843b65020d8c..73f8705060402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +27,17 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +137,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(getData(provider, 19) === Set("a" -> 19)) } - test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") @@ -344,7 +343,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } - test("SPARK-18342: commit fails when rename fails") { + testQuietly("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ val dir = scheme + "://" + newDir() val conf = new Configuration() @@ -366,7 +365,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def numTempFiles: Int = { if (deltaFileDir.exists) { - deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + deltaFileDir.listFiles.map(_.getName).count(n => n.endsWith(".tmp")) } else 0 } @@ -471,6 +470,43 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("error writing [version].delta cancels the output stream") { + + val hadoopConf = new Configuration() + hadoopConf.set( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, + classOf[CreateAtomicTestManager].getName) + val remoteDir = Utils.createTempDir().getAbsolutePath + + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ... + store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -720,6 +756,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] * this provider */ def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] + + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } } object StateStoreTestsHelper { From 73f28530d6f6dd8aba758ea818c456cf911a5f41 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Apr 2018 08:59:04 +0800 Subject: [PATCH 312/593] [SPARK-23979][SQL] MultiAlias should not be a CodegenFallback ## What changes were proposed in this pull request? Just found `MultiAlias` is a `CodegenFallback`. It should not be as looks like `MultiAlias` won't be evaluated. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21065 from viirya/multialias-without-codegenfallback. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a65f58fa61ff4..71e23175168e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode @@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression with CodegenFallback { + extends UnaryExpression with NamedExpression with Unevaluable { override def name: String = throw new UnresolvedException(this, "name") From c0964935d614bf345535439bce01cbd0e60c86aa Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Mon, 16 Apr 2018 12:01:42 +0800 Subject: [PATCH 313/593] [SPARK-23956][YARN] Use effective RPC port in AM registration ## What changes were proposed in this pull request? We propose not to hard-code the RPC port in the AM registration. ## How was this patch tested? Tested application reports from a pseudo-distributed cluster ``` 18/04/10 14:56:21 INFO Client: client token: N/A diagnostics: N/A ApplicationMaster host: localhost ApplicationMaster RPC port: 58338 queue: default start time: 1523397373659 final status: UNDEFINED tracking URL: http://localhost:8088/proxy/application_1523370127531_0016/ ``` Author: Gera Shegalov Closes #21047 from gerashegalov/gera/am-to-rm-nmhost. --- .../scala/org/apache/spark/deploy/yarn/YarnRMClient.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index c1ae12aabb8cc..17234b120ae13 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. @@ -71,7 +70,8 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, + trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, From 69310220319163bac18c9ee69d7da6d92227253b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 15 Apr 2018 21:45:55 -0700 Subject: [PATCH 314/593] [SPARK-23917][SQL] Add array_max function ## What changes were proposed in this pull request? The PR adds the SQL function `array_max`. It takes an array as argument and returns the maximum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21024 from mgaido91/SPARK-23917. --- python/pyspark/sql/functions.py | 15 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 68 ++++++++++++++++++- .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 133 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..f3492ae42639c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 131b958239e41..05bfa2dd45340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9212c3de1f814..942dfd4292610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0abfc9fa4c465..c86c5beded9d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..e2614a179aad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..a2384019533b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..daf407926dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..5d5d92c84df6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 083cf223569b7896e35ff1d53a73498a4971b28d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 16 Apr 2018 23:50:50 +0800 Subject: [PATCH 315/593] [SPARK-21033][CORE][FOLLOW-UP] Update Spillable ## What changes were proposed in this pull request? Update ```scala SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) ``` to ```scala SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) ``` because of `SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD`'s default value is `Integer.MAX_VALUE`: https://github.com/apache/spark/blob/c99fc9ad9b600095baba003053dbf84304ca392b/core/src/main/scala/org/apache/spark/internal/config/package.scala#L503-L511 ## How was this patch tested? N/A Author: Yuming Wang Closes #21077 from wangyum/SPARK-21033. --- .../org/apache/spark/util/collection/Spillable.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 8183f825592c0..81457b53cd814 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} /** @@ -41,7 +42,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - protected def elementsRead: Long = _elementsRead + protected def elementsRead: Int = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -54,15 +55,15 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // Force this collection to spill when there are this many elements in memory // For testing only - private[this] val numElementsForceSpillThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + private[this] val numElementsForceSpillThreshold: Int = + SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 @volatile private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill - private[this] var _elementsRead = 0L + private[this] var _elementsRead = 0 // Number of bytes spilled in total @volatile private[this] var _memoryBytesSpilled = 0L From 5003736ad60c3231bb18264c9561646c08379170 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 16 Apr 2018 11:27:30 -0500 Subject: [PATCH 316/593] [SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel add RawPrediction as output column add numClasses and numFeatures to OneVsRestModel ## What changes were proposed in this pull request? - Add two val numClasses and numFeatures in OneVsRestModel so that we can inherit from Classifier in the future - Add rawPrediction output column in transform, the prediction label in calculated by the rawPrediciton like raw2prediction ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21044 from ludatabricks/SPARK-9312. --- .../spark/ml/classification/OneVsRest.scala | 56 +++++++++++++++---- .../ml/classification/OneVsRestSuite.scala | 7 ++- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f04fde2cbbca1..5348d882cfd67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams +private[ml] trait OneVsRestParams extends ClassifierParams with ClassifierTypeTrait with HasWeightCol { /** @@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + require(models.nonEmpty, "OneVsRestModel requires at least one model for one class") + + @Since("2.4.0") + val numClasses: Int = models.length + + @Since("2.4.0") + val numFeatures: Int = models.head.numFeatures + /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset @@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble - } + if (getRawPredictionCol != "") { + val numClass = models.length - // output label and label metadata as prediction - aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + // output the RawPrediction as vector + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + val predArray = Array.fill[Double](numClass)(0.0) + predictions.foreach { case (idx, value) => predArray(idx) = value } + Vectors.dense(predArray) + } + + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } + + // output confidence as raw prediction, label and label metadata as prediction + aggregatedDataset + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) + .drop(accColName) + } else { + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + // output label and label metadata as prediction + aggregatedDataset + .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } } @Since("1.4.1") @@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + /** * The implementation of parallel one vs. rest runs the classification for * each class in a separate threads. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 11e88367108b4..2c3417c7e4028 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") + assert(ova.getRawPredictionCol === "rawPrediction") val ovaModel = ova.fit(dataset) MLTestingUtils.checkCopyAndUids(ova, ovaModel) - assert(ovaModel.models.length === numClasses) + assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") + ovaModel.setRawPredictionCol("") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet assert(outputFields === Set("y", "fea", "pred")) @@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val ovr = new OneVsRest() .setClassifier(logReg) val output = ovr.fit(dataset).transform(dataset) - assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + assert(output.schema.fieldNames.toSet + === Set("label", "features", "prediction", "rawPrediction")) } test("SPARK-21306: OneVsRest should support setWeightCol") { From 04614820e103feeae91299dc90dba1dd628fd485 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 16 Apr 2018 11:31:24 -0500 Subject: [PATCH 317/593] [SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API ## What changes were proposed in this pull request? Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting. ## How was this patch tested? UT added. Author: WeichenXu Closes #19627 from WeichenXu123/expose-model-list-py. --- .../spark/ml/tuning/CrossValidator.scala | 11 ++ .../ml/tuning/TrainValidationSplit.scala | 11 ++ .../ml/param/_shared_params_code_gen.py | 5 + python/pyspark/ml/param/shared.py | 24 ++++ python/pyspark/ml/tests.py | 78 +++++++++++++ python/pyspark/ml/tuning.py | 107 +++++++++++++----- python/pyspark/ml/util.py | 4 + 7 files changed, 211 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a0b507d2e718c..c2826dcc08634 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[JList[Model[_]]]) + : CrossValidatorModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray.map(_.asScala.toArray)) + } else { + None + } + this + } + /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 88ff0dfd75e96..8d1b9a8ddab59 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[Model[_]]) + : TrainValidationSplitModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray) + } else { + None + } + this + } + /** * @return submodels represented in array. The index of array corresponds to the ordering of * estimatorParamMaps diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index db951d81de1e7..6e9e0a34cdfde 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -157,6 +157,11 @@ def get$Name(self): "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", "1", "TypeConverters.toInt"), + ("collectSubModels", "Param for whether to collect a list of sub-models trained during " + + "tuning. If set to false, then only the single best sub-model will be available after " + + "fitting. If set to true, then all sub-models will be available. Warning: For large " + + "models, collecting all sub-models can cause OOMs on the Spark driver.", + "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 474c38764e5a1..08408ee8fbfcc 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -655,6 +655,30 @@ def getParallelism(self): return self.getOrDefault(self.parallelism) +class HasCollectSubModels(Params): + """ + Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. + """ + + collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean) + + def __init__(self): + super(HasCollectSubModels, self).__init__() + self._setDefault(collectSubModels=False) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + + def getCollectSubModels(self): + """ + Gets the value of collectSubModels or its default value. + """ + return self.getOrDefault(self.collectSubModels) + + class HasLoss(Params): """ Mixin for param loss: the loss function to be optimized. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4ce54547eab09..2ec0be60e9fa9 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1018,6 +1018,50 @@ def test_parallel_evaluation(self): cvParallelModel = cv.fit(dataset) self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -1186,6 +1230,40 @@ def test_parallel_evaluation(self): tvsParallelModel = tvs.fit(dataset) self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 545e24ca05aa5..0c8029f293cfe 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasParallelism, HasSeed +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -33,7 +33,7 @@ 'TrainValidationSplitModel'] -def _parallelFitTasks(est, train, eva, validation, epm): +def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. @@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm): :param eva: Evaluator, used to compute `metric` :param validation: DataFrame, validation data set, used for evaluation. :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. - :return: (int, float), an index into `epm` and the associated metric value. + :param collectSubModel: Whether to collect sub model. + :return: (int, float, subModel), an index into `epm` and the associated metric value. """ modelIter = est.fitMultiple(train, epm) def singleTask(): index, model = next(modelIter) metric = eva.evaluate(model.transform(validation, epm[index])) - return index, metric + return index, metric, model if collectSubModel else None return [singleTask] * len(epm) @@ -194,7 +195,8 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1) + seed=None, parallelism=1, collectSubModels=False) """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3, parallelism=1) @@ -246,10 +248,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -282,6 +284,10 @@ def _fit(self, dataset): metrics = [0.0] * numModels pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [[None for j in range(numModels)] for i in range(nFolds)] for i in range(nFolds): validateLB = i * h @@ -290,9 +296,12 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] += (metric / nFolds) + if collectSubModelsParam: + subModels[i][j] = subModel + validation.unpersist() train.unpersist() @@ -301,7 +310,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics)) + return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels)) @since("1.4.0") def copy(self, extra=None): @@ -345,9 +354,11 @@ def _from_java(cls, java_stage): numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed, parallelism=parallelism) + numFolds=numFolds, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -367,6 +378,7 @@ def _to_java(self): _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) _java_obj.setParallelism(self.getParallelism()) + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 1.4.0 """ - def __init__(self, bestModel, avgMetrics=[]): + def __init__(self, bestModel, avgMetrics=[], subModels=None): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics + #: sub model list from cross validation + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -399,6 +413,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -407,7 +422,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = self.avgMetrics - return CrossValidatorModel(bestModel, avgMetrics) + subModels = self.subModels + return CrossValidatorModel(bestModel, avgMetrics, subModels) @since("2.3.0") def write(self): @@ -426,13 +442,17 @@ def _from_java(cls, java_stage): Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. """ - bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [[JavaParams._from_java(sub_model) + for sub_model in fold_sub_models] + for fold_sub_models in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -454,10 +474,16 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels] + _java_obj.setSubModels(java_sub_models) return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ .. note:: Experimental @@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None) + parallelism=1, collectSubModels=False, seed=None) """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75, parallelism=1) @@ -505,10 +531,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -541,11 +567,19 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [None for i in range(numModels)] + + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) metrics = [None] * numModels - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric + if collectSubModelsParam: + subModels[j] = subModel + train.unpersist() validation.unpersist() @@ -554,7 +588,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) @since("2.0.0") def copy(self, extra=None): @@ -598,9 +632,11 @@ def _from_java(cls, java_stage): trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed, parallelism=parallelism) + trainRatio=trainRatio, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -620,7 +656,7 @@ def _to_java(self): _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) _java_obj.setParallelism(self.getParallelism()) - + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=[]): + def __init__(self, bestModel, validationMetrics=[], subModels=None): super(TrainValidationSplitModel, self).__init__() - #: best model from cross validation + #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics self.validationMetrics = validationMetrics + #: sub models from train validation split + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -651,6 +689,7 @@ def copy(self, extra=None): creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -659,7 +698,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) - return TrainValidationSplitModel(bestModel, validationMetrics) + subModels = self.subModels + return TrainValidationSplitModel(bestModel, validationMetrics, subModels) @since("2.3.0") def write(self): @@ -687,6 +727,10 @@ def _from_java(cls, java_stage): py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [JavaParams._from_java(sub_model) + for sub_model in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -708,6 +752,11 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + _java_obj.setSubModels(java_sub_models) + return _java_obj diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index c3c47bd79459a..a486c6a3fdeb5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -169,6 +169,10 @@ def overwrite(self): self._jwrite.overwrite() return self + def option(self, key, value): + self._jwrite.option(key, value) + return self + def context(self, sqlContext): """ Sets the SQL context to use for saving. From fd990a908b94d1c90c4ca604604f35a13b453d44 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Apr 2018 22:45:57 +0200 Subject: [PATCH 318/593] [SPARK-23873][SQL] Use accessors in interpreted LambdaVariable ## What changes were proposed in this pull request? Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access element. We should use specified accessors if possible. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20981 from viirya/SPARK-23873. --- .../spark/sql/catalyst/InternalRow.scala | 26 ++++++++++++- .../catalyst/expressions/BoundAttribute.scala | 22 ++--------- .../expressions/objects/objects.scala | 8 +++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/ObjectExpressionsSuite.scala | 38 ++++++++++++++++++- 5 files changed, 75 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 29110640d64f2..274d75e680f03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -119,4 +119,28 @@ object InternalRow { case v: MapData => v.copy() case _ => value } + + /** + * Returns an accessor for an `InternalRow` with given data type. The returned accessor + * actually takes a `SpecializedGetters` input because it can be generalized to other classes + * that implements `SpecializedGetters` (e.g., `ArrayData`) too. + */ + def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType) + case _ => (input, ordinal) => input.get(ordinal, dataType) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5021a567592e0..4cc84b27d9eb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (input.isNullAt(ordinal)) { + if (nullable && input.isNullAt(ordinal)) { null } else { - dataType match { - case BooleanType => input.getBoolean(ordinal) - case ByteType => input.getByte(ordinal) - case ShortType => input.getShort(ordinal) - case IntegerType | DateType => input.getInt(ordinal) - case LongType | TimestampType => input.getLong(ordinal) - case FloatType => input.getFloat(ordinal) - case DoubleType => input.getDouble(ordinal) - case StringType => input.getUTF8String(ordinal) - case BinaryType => input.getBinary(ordinal) - case CalendarIntervalType => input.getInterval(ordinal) - case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => input.getStruct(ordinal, t.size) - case _: ArrayType => input.getArray(ordinal) - case _: MapType => input.getMap(ordinal) - case _ => input.get(ordinal, dataType) - } + accessor(input, ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e90ca550807..77802e89e942b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -560,11 +560,17 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - input.get(0, dataType) + if (nullable && input.isNullAt(0)) { + null + } else { + accessor(input, 0) + } } override def genCode(ctx: CodegenContext): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..b4bf6d7107d7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], MapData and Row. */ - protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b1bc67dfac1b5..b0188b0098def 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util._ @@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("LambdaVariable should support interpreted execution") { + def genSchema(dt: DataType): Seq[StructType] = { + Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), + StructType(StructField("col_1", dt, nullable = true) :: Nil)) + } + + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val mapTypes = elementTypes.flatMap { elementType => + Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true)) + } + val structTypes = elementTypes.flatMap { elementType => + Seq(StructType(StructField("col1", elementType, false) :: Nil), + StructType(StructField("col1", elementType, true) :: Nil)) + } + + val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes + val random = new Random(100) + testTypes.foreach { dt => + genSchema(dt).map { schema => + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable) + checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) + } + } + } } class TestBean extends Serializable { From 14844a62c025e7299029d7452b8c4003bc221ac8 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 17:55:35 +0900 Subject: [PATCH 319/593] [SPARK-23918][SQL] Add array_min function ## What changes were proposed in this pull request? The PR adds the SQL function `array_min`. It takes an array as argument and returns the minimum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21025 from mgaido91/SPARK-23918. --- python/pyspark/sql/functions.py | 17 ++++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 64 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 131 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f3492ae42639c..6ca22b610843d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_min(col): + """ + Collection function: returns the minimum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_min(_to_java_column(col))) + + @since(2.4) def array_max(col): """ @@ -2108,7 +2123,7 @@ def sort_array(col, asc=True): [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] - """ + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 05bfa2dd45340..4dd1ca509bf2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 942dfd4292610..d4e322d23b95b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -595,11 +595,7 @@ case class Least(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfSmaller(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c86c5beded9d0..d97611c98ac91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is smaller than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, partialResult.value, item.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code for updating `partialResult` if `item` is greater than it. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e2614a179aad8..7c87777eed47a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -288,6 +288,70 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} /** * Returns the maximum value in the array. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2384019533b7..5a31e3a30edd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Array Min") { + checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) + checkEvaluation( + ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "") + checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234) + } + test("Array max") { checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index daf407926dca4..642ac056bb809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the minimum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } + /** * Returns the maximum value in the array. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d5d92c84df6d..636e86baedf6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_min function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(1), Row(null), Row(null), Row(-100)) + + checkAnswer(df.select(array_min(df("a"))), answer) + checkAnswer(df.selectExpr("array_min(a)"), answer) + } + test("array_max function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From 1cc66a072b7fd3bf140fa41596f6b18f8d1bd7b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Apr 2018 01:59:38 -0700 Subject: [PATCH 320/593] [SPARK-23687][SS] Add a memory source for continuous processing. ## What changes were proposed in this pull request? Add a memory source for continuous processing. Note that only one of the ContinuousSuite tests is migrated to minimize the diff here. I'll submit a second PR for SPARK-23688 to change the rest and get rid of waitForRateSourceTriggers. ## How was this patch tested? unit test Author: Jose Torres Closes #20828 from jose-torres/continuousMemory. --- .../continuous/ContinuousExecution.scala | 5 +- .../sql/execution/streaming/memory.scala | 59 +++-- .../sources/ContinuousMemoryStream.scala | 211 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 31 +-- 5 files changed, 266 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1758b3844bd62..951d694355ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -317,8 +318,10 @@ class ContinuousExecution( synchronized { if (queryExecutionThread.isAlive) { commitLog.add(epoch) - val offset = offsetLog.get(epoch).get.offsets(0).get + val offset = + continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) + continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) } else { return } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 352d4ce9fbcaa..628923d367ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -47,16 +49,43 @@ object MemoryStream { new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) } +/** + * A base class for memory stream implementations. Supports adding data and resetting. + */ +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { + protected val encoder = encoderFor[A] + protected val attributes = encoder.schema.toAttributes + + def toDS(): Dataset[A] = { + Dataset[A](sqlContext.sparkSession, logicalPlan) + } + + def toDF(): DataFrame = { + Dataset.ofRows(sqlContext.sparkSession, logicalPlan) + } + + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + + def readSchema(): StructType = encoder.schema + + protected def logicalPlan: LogicalPlan + + def addData(data: TraversableOnce[A]): Offset +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - protected val encoder = encoderFor[A] - private val attributes = encoder.schema.toAttributes - protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) + extends MemoryStreamBase[A](sqlContext) + with MicroBatchReader with SupportsScanUnsafeRow with Logging { + + protected val logicalPlan: LogicalPlan = + StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected var currentOffset: LongOffset = new LongOffset(-1) @GuardedBy("this") - private var startOffset = new LongOffset(-1) + protected var startOffset = new LongOffset(-1) @GuardedBy("this") private var endOffset = new LongOffset(-1) @@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def toDS(): Dataset[A] = { - Dataset(sqlContext.sparkSession, logicalPlan) - } - - def toDF(): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, logicalPlan) - } - - def addData(data: A*): Offset = { - addData(data.toTraversable) - } - def addData(data: TraversableOnce[A]): Offset = { val objects = data.toSeq val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray @@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def readSchema(): StructType = encoder.schema - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) override def getStartOffset: OffsetV2 = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala new file mode 100644 index 0000000000000..c28919b8b729b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.{Encoder, Row, SQLContext} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + +/** + * The overall strategy here is: + * * ContinuousMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * RecordEndpoint is set up as an endpoint for executor-side + * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified + * offset within the list, or null if that offset doesn't yet have a record. + */ +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) + private val NUM_PARTITIONS = 2 + + protected val logicalPlan = + StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + + // ContinuousReader implementation + + @GuardedBy("this") + private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + + private val recordEndpoint = new RecordEndpoint() + @volatile private var endpointRef: RpcEndpointRef = _ + + def addData(data: TraversableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.toSeq.zipWithIndex.map { + case (item, index) => records(index % NUM_PARTITIONS) += item + } + + // The new target offset is the offset where all records in all partitions have been processed. + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + } + + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. + startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset + } + + override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset( + offsets.map { + case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + }.toMap + ) + } + + override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + synchronized { + val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = + recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + new ContinuousMemoryStreamDataReaderFactory( + endpointName, part, index): DataReaderFactory[Row] + }.toList.asJava + } + } + + override def stop(): Unit = { + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + override def commit(end: Offset): Unit = {} + + // ContinuousReadSupport implementation + // This is necessary because of how StreamTest finds the source for AddDataMemory steps. + def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + this + } + + /** + * Endpoint for executors to poll for records. + */ + private class RecordEndpoint extends ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => + ContinuousMemoryStream.this.synchronized { + val buf = records(part) + val record = if (buf.size <= index) None else Some(buf(index)) + + context.reply(record.map(Row(_))) + } + } + } +} + +object ContinuousMemoryStream { + case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * Data reader factory for continuous memory stream. + */ +class ContinuousMemoryStreamDataReaderFactory( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends DataReaderFactory[Row] { + override def createDataReader: ContinuousMemoryStreamDataReader = + new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) +} + +/** + * Data reader for continuous memory stream. + * + * Polls the driver endpoint for new records. + */ +class ContinuousMemoryStreamDataReader( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends ContinuousDataReader[Row] { + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[Row] = None + + override def next(): Boolean = { + current = None + while (current.isEmpty) { + Thread.sleep(10) + current = endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + } + currentOffset += 1 + true + } + + override def get(): Row = current.get + + override def close(): Unit = {} + + override def getOffset: ContinuousMemoryStreamPartitionOffset = + ContinuousMemoryStreamPartitionOffset(partition, currentOffset) +} + +case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) + extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(partitionNums) +} + +case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) + extends PartitionOffset diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 00741d660dd2d..af0268fa47871 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * been processed. */ object AddData { - def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] = AddDataMemory(source, data) } @@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def runAction(): Unit } - case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index ef74efef156d5..c318b951ff992 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -53,32 +54,24 @@ class ContinuousSuiteBase extends StreamTest { // A continuous trigger that will only fire the initial time for the duration of a test. // This allows clean testing with manual epoch advancement. protected val longContinuousTrigger = Trigger.Continuous("1 hour") + + override protected val defaultTrigger = Trigger.Continuous(100) + override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { import testImplicits._ - test("basic rate source") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + test("basic") { + val input = ContinuousMemoryStream[Int] - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(input.toDF())( + AddData(input, 0, 1, 2), + CheckAnswer(0, 1, 2), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), - StopStream) + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(0, 1, 2, 3, 4, 5)) } test("map") { From 05ae74778a10fbdd7f2cbf7742de7855966b7d35 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Tue, 17 Apr 2018 04:13:17 -0700 Subject: [PATCH 321/593] [SPARK-23747][STRUCTURED STREAMING] Add EpochCoordinator unit tests ## What changes were proposed in this pull request? Unit tests for EpochCoordinator that test correct sequencing of committed epochs. Several tests are ignored since they test functionality implemented in SPARK-23503 which is not yet merged, otherwise they fail. Author: Efim Poberezkin Closes #20983 from efimpoberezkin/pr/EpochCoordinator-tests. --- .../continuous/EpochCoordinatorSuite.scala | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala new file mode 100644 index 0000000000000..99e30561f81d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.mockito.InOrder +import org.mockito.Matchers.{any, eq => eqTo} +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.LocalSparkSession +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.test.TestSparkSession + +class EpochCoordinatorSuite + extends SparkFunSuite + with LocalSparkSession + with MockitoSugar + with BeforeAndAfterEach { + + private var epochCoordinator: RpcEndpointRef = _ + + private var writer: StreamWriter = _ + private var query: ContinuousExecution = _ + private var orderVerifier: InOrder = _ + + override def beforeEach(): Unit = { + val reader = mock[ContinuousReader] + writer = mock[StreamWriter] + query = mock[ContinuousExecution] + orderVerifier = inOrder(writer, query) + + spark = new TestSparkSession() + + epochCoordinator + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + } + + test("single epoch") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + // Here and in subsequent tests this is called to make a synchronous call to EpochCoordinator + // so that mocks would have been acted upon by the time verification happens + makeSynchronousCall() + + verifyCommit(1) + } + + test("single epoch, all but one writer partition has committed") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("single epoch, all but one reader partition has reported an offset") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("consequent epochs, messages for epoch (k + 1) arrive after messages for epoch k") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + // Message that arrives late + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 5) + reportPartitionOffset(0, 5) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) + } + + private def setWriterPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) + } + + private def setReaderPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetReaderPartitions(numPartitions)) + } + + private def commitPartitionEpoch(partitionId: Int, epoch: Long): Unit = { + val dummyMessage: WriterCommitMessage = mock[WriterCommitMessage] + epochCoordinator.send(CommitPartitionEpoch(partitionId, epoch, dummyMessage)) + } + + private def reportPartitionOffset(partitionId: Int, epoch: Long): Unit = { + val dummyOffset: PartitionOffset = mock[PartitionOffset] + epochCoordinator.send(ReportPartitionOffset(partitionId, epoch, dummyOffset)) + } + + private def makeSynchronousCall(): Unit = { + epochCoordinator.askSync[Long](GetCurrentEpoch) + } + + private def verifyCommit(epoch: Long): Unit = { + orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(query).commit(epoch) + } + + private def verifyNoCommitFor(epoch: Long): Unit = { + verify(writer, never()).commit(eqTo(epoch), any()) + verify(query, never()).commit(epoch) + } + + private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { + epochs.foreach(verifyCommit) + } +} From 30ffb53cad84283b4f7694bfd60bdd7e1101b04e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 17 Apr 2018 15:09:36 +0200 Subject: [PATCH 322/593] [SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? We don't have a good way to sequentially access `UnsafeArrayData` with a common interface such as `Seq`. An example is `MapObject` where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`. Calling `toArray` will copy the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData`, so we can avoid copying the entire array. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20984 from viirya/SPARK-23875. --- .../expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/util/ArrayData.scala | 30 +++++- .../util/ArrayDataIndexedSeqSuite.scala | 100 ++++++++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 77802e89e942b..72b202b3a5020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -708,7 +708,7 @@ case class MapObjects private( } } case ArrayType(et, _) => - _.asInstanceOf[ArrayData].array + _.asInstanceOf[ArrayData].toSeq[Any](et) } private lazy val mapElements: Seq[_] => Any = customCollectionCls match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 9beef41d639f3..2cf59d567c08c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def toSeq[T](dataType: DataType): IndexedSeq[T] = + new ArrayDataIndexedSeq[T](this, dataType) + def setNullAt(i: Int): Unit def update(i: Int, value: Any): Unit @@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable { } } } + +/** + * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData` + * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`, + * instead of `IndexedSeq[Int]`, in order to keep the null elements. + */ +class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] { + + private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType) + + override def apply(idx: Int): T = + if (0 <= idx && idx < arrayData.numElements()) { + if (arrayData.isNullAt(idx)) { + null.asInstanceOf[T] + } else { + accessor(arrayData, idx).asInstanceOf[T] + } + } else { + throw new IndexOutOfBoundsException( + s"Index $idx must be between 0 and the length of the ArrayData.") + } + + override def length: Int = arrayData.numElements() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala new file mode 100644 index 0000000000000..6400898343ae7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.types._ + +class ArrayDataIndexedSeqSuite extends SparkFunSuite { + private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = { + assert(arrayData.numElements == array.length) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For NaN, etc. + case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e)) + case _ => assert(arrayData.get(i, elementDt) === e) + } + } else { + assert(arrayData.isNullAt(i)) + } + } + + val seq = arrayData.toSeq[Any](elementDt) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For Nan, etc. + case FloatType | DoubleType => assert(seq(i).equals(e)) + case _ => assert(seq(i) === e) + } + } else { + assert(seq(i) == null) + } + } + + intercept[IndexOutOfBoundsException] { + seq(-1) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + + intercept[IndexOutOfBoundsException] { + seq(seq.length) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + } + + private def testArrayData(): Unit = { + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val random = new Random(100) + arrayTypes.foreach { dt => + val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + + val unsafeRowConverter = UnsafeProjection.create(schema) + val safeRowConverter = FromUnsafeProjection(schema) + + val unsafeRow = unsafeRowConverter(internalRow) + val safeRow = safeRowConverter(unsafeRow) + + val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData] + val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + + val elementType = dt.elementType + test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) { + compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType)) + } + + test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) { + compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType)) + } + } + } + + testArrayData() +} From 0a9172a05e604a4a94adbb9208c8c02362afca00 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 21:45:20 +0800 Subject: [PATCH 323/593] [SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization ## What changes were proposed in this pull request? There was no check on nullability for arguments of `Tuple`s. This could lead to have weird behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` got silently transformed in a valid value (like `-1` for `Int`), corresponding to the default value we are using in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...). The PR adds the `AssertNotNull` to arguments of tuples which have been asked to be converted to non-nullable types. ## How was this patch tested? added UT Author: Marco Gaido Closes #20976 from mgaido91/SPARK-23835. --- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 14 +++++++------- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index fc890a0cfdac3..ddfc0c1a4be2d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 42f8b4c7657e2..7079ac6453ffc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1aae3aea3a31a..e4274aaa9727e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { + val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = deserializerFor( + deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + } - if (!nullable) { - AssertNotNull(constructor, newTypePath) - } else { - constructor - } + if (!nullable) { + AssertNotNull(constructor, newTypePath) + } else { + constructor } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 8c3db48a01f12..353b8344658f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite { StructField("_2", NullType, nullable = true))), nullable = true)) } + + test("SPARK-23835: add null check to non-nullable types in Tuples") { + def numberOfCheckedArguments(deserializer: Expression): Int = { + assert(deserializer.isInstanceOf[NewInstance]) + deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + } + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9b745befcb611..e0f4d2ba685e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val group2 = cached.groupBy("x").agg(min(col("z")) as "value") checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) } + + test("SPARK-23835: null primitive data type should throw NullPointerException") { + val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() + intercept[NullPointerException](ds.as[(Int, Int)].collect()) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From ed4101d29f50d54fd7846421e4c00e9ecd3599d0 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 21:52:33 +0800 Subject: [PATCH 324/593] [SPARK-22676] Avoid iterating all partition paths when spark.sql.hive.verifyPartitionPath=true MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code, it will scanning all partition paths when spark.sql.hive.verifyPartitionPath=true. e.g. table like below: ``` CREATE TABLE `test`( `id` int, `age` int, `name` string) PARTITIONED BY ( `A` string, `B` string) load data local inpath '/tmp/data0' into table test partition(A='00', B='00') load data local inpath '/tmp/data1' into table test partition(A='01', B='01') load data local inpath '/tmp/data2' into table test partition(A='10', B='10') load data local inpath '/tmp/data3' into table test partition(A='11', B='11') ``` If I query with SQL – "select * from test where A='00' and B='01' ", current code will scan all partition paths including '/data/A=00/B=00', '/data/A=00/B=00', '/data/A=01/B=01', '/data/A=10/B=10', '/data/A=11/B=11'. It costs much time and memory cost. This pr proposes to avoid iterating all partition paths. Add a config `spark.files.ignoreMissingFiles` and ignore the `file not found` when `getPartitions/compute`(for hive table scan). This is much like the logic brought by `spark.sql.files.ignoreMissingFiles`(which is for datasource scan). ## How was this patch tested? UT Author: jinxing Closes #19868 from jinxing64/SPARK-22676. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 43 +++++++++--- .../org/apache/spark/rdd/NewHadoopRDD.scala | 45 ++++++++---- .../scala/org/apache/spark/FileSuite.scala | 69 ++++++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../spark/sql/hive/QueryPartitionSuite.scala | 40 +++++++++++ 6 files changed, 181 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 407545aa4a47a..99d779fb600e8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -301,6 +301,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val IGNORE_MISSING_FILES = ConfigBuilder("spark.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") .stringConf .createOptional diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2480559a41b7a..44895abc7bd4d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred._ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -134,6 +135,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. @@ -197,17 +200,24 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) - val inputSplits = if (ignoreEmptySplits) { - allInputSplits.filter(_.getLength > 0) - } else { - allInputSplits - } - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) + try { + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } + val array = new Array[Partition](inputSplits.size) + for (i <- 0 until inputSplits.size) { + array(i) = new HadoopPartition(id, i, inputSplits(i)) + } + array + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - array } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -256,6 +266,12 @@ class HadoopRDD[K, V]( try { inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true @@ -276,6 +292,11 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e4dd1b6a82498..ff66a04859d10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,7 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileInputFormat, FileSplit, InvalidInputException} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark._ @@ -90,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { @@ -124,17 +126,25 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala - val rawSplits = if (ignoreEmptySplits) { - allRowSplits.filter(_.getLength > 0) - } else { - allRowSplits - } - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + try { + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${_conf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - result } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -189,6 +199,12 @@ class NewHadoopRDD[K, V]( _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) _reader } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", @@ -213,6 +229,11 @@ class NewHadoopRDD[K, V]( try { finished = !reader.nextKeyValue } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 55a9122cf9026..a441b9c8ab97a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec @@ -32,7 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.internal.config._ -import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -596,4 +597,70 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { actualPartitionNum = 5, expectedPartitionNum = 2) } + + test("spark.files.ignoreMissingFiles should work both HadoopRDD and NewHadoopRDD") { + // "file not found" can happen both when getPartitions or compute in HadoopRDD/NewHadoopRDD, + // We test both cases here. + + val deletedPath = new Path(tempDir.getAbsolutePath, "test-data-1") + val fs = deletedPath.getFileSystem(new Configuration()) + fs.delete(deletedPath, true) + intercept[FileNotFoundException](fs.open(deletedPath)) + + def collectRDDAndDeleteFileBeforeCompute(newApi: Boolean): Array[_] = { + val dataPath = new Path(tempDir.getAbsolutePath, "test-data-2") + val writer = new OutputStreamWriter(new FileOutputStream(new File(dataPath.toString))) + writer.write("hello\n") + writer.write("world\n") + writer.close() + val rdd = if (newApi) { + sc.newAPIHadoopFile(dataPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]) + } else { + sc.textFile(dataPath.toString) + } + rdd.partitions + fs.delete(dataPath, true) + // Exception happens when initialize record reader in HadoopRDD/NewHadoopRDD.compute + // because partitions' info already cached. + rdd.collect() + } + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=false by default. + sc = new SparkContext("local", "test") + intercept[org.apache.hadoop.mapred.InvalidInputException] { + // Exception happens when HadoopRDD.getPartitions + sc.textFile(deletedPath.toString).collect() + } + + var e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(false) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + intercept[org.apache.hadoop.mapreduce.lib.input.InvalidInputException] { + // Exception happens when NewHadoopRDD.getPartitions + sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect + } + + e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(true) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + sc.stop() + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=true. + val conf = new SparkConf().set(IGNORE_MISSING_FILES, true) + sc = new SparkContext("local", "test", conf) + assert(sc.textFile(deletedPath.toString).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(false).isEmpty) + + assert(sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0dc47bfe075d0..3729bd5293eca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -437,7 +437,8 @@ object SQLConf { val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + - "when reading data stored in HDFS.") + "when reading data stored in HDFS. This configuration will be deprecated in the future " + + "releases and replaced by spark.files.ignoreMissingFiles.") .booleanConf .createWithDefault(false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index b2dc401ce1efc..78156b17fb43b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem +import org.apache.spark.internal.config._ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -70,6 +71,45 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl } } + test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + + sql("DROP TABLE IF EXISTS table_with_partition") + sql("DROP TABLE IF EXISTS createAndInsertTest") + } + } + test("SPARK-21739: Cast expression should initialize timezoneId") { withTable("table_with_timestamp_partition") { sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") From 3990daaf3b6ca2c5a9f7790030096262efb12cb2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 08:55:01 -0500 Subject: [PATCH 325/593] [SPARK-23948] Trigger mapstage's job listener in submitMissingTasks ## What changes were proposed in this pull request? SparkContext submitted a map stage from `submitMapStage` to `DAGScheduler`, `markMapStageJobAsFinished` is called only in (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L933 and https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1314); But think about below scenario: 1. stage0 and stage1 are all `ShuffleMapStage` and stage1 depends on stage0; 2. We submit stage1 by `submitMapStage`; 3. When stage 1 running, `FetchFailed` happened, stage0 and stage1 got resubmitted as stage0_1 and stage1_1; 4. When stage0_1 running, speculated tasks in old stage1 come as succeeded, but stage1 is not inside `runningStages`. So even though all splits(including the speculated tasks) in stage1 succeeded, job listener in stage1 will not be called; 5. stage0_1 finished, stage1_1 starts running. When `submitMissingTasks`, there is no missing tasks. But in current code, job listener is not triggered. We should call the job listener for map stage in `5`. ## How was this patch tested? Not added yet. Author: jinxing Closes #21019 from jinxing64/SPARK-23948. --- .../apache/spark/scheduler/DAGScheduler.scala | 33 ++++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 52 +++++++++++++++++++ 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c46a84323392..78b6b34b5d2bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1092,17 +1092,16 @@ class DAGScheduler( // the stage as completed here in case there are no tasks to run markStageAsFinished(stage, None) - val debugString = stage match { + stage match { case stage: ShuffleMapStage => - s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) case stage : ResultStage => - s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") } - logDebug(debugString) - submitWaitingChildStages(stage) } } @@ -1307,13 +1306,7 @@ class DAGScheduler( shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - // Mark any map-stage jobs waiting on this stage as finished - if (shuffleStage.mapStageJobs.nonEmpty) { - val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) - for (job <- shuffleStage.mapStageJobs) { - markMapStageJobAsFinished(job, stats) - } - } + markMapStageJobsAsFinished(shuffleStage) submitWaitingChildStages(shuffleStage) } } @@ -1433,6 +1426,16 @@ class DAGScheduler( } } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + /** * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d812b5bd92c1b..8b6ec37625eec 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2146,6 +2146,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Trigger mapstage's job listener in submitMissingTasks") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + + // Complete the stage0. + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting stage1, trigger a fetch failure. + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + // Stage1 listener should not have a result yet + assert(listener2.results.size === 0) + + // Speculative task succeeded in stage1. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostD", rdd2.partitions.length))) + // stage1 listener still should not have a result, though there's no missing partitions + // in it. Because stage1 has been failed and is not inside `runningStages` at this moment. + assert(listener2.results.size === 0) + + // Stage0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // After stage0 is finished, stage1 will be submitted and found there is no missing + // partitions in it. Then listener got triggered. + assert(listener2.results.size === 1) + assertDataStructuresEmpty() + } + /** * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported From f39e82ce150b6a7ea038e6858ba7adbaba3cad88 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Apr 2018 00:35:44 +0800 Subject: [PATCH 326/593] [SPARK-23986][SQL] freshName can generate non-unique names ## What changes were proposed in this pull request? We are using `CodegenContext.freshName` to get a unique name for any new variable we are adding. Unfortunately, this method currently fails to create a unique name when we request more than one instance of variables with starting name `name1` and an instance with starting name `name11`. The PR changes the way a new name is generated by `CodegenContext.freshName` so that we generate unique names in this scenario too. ## How was this patch tested? added UT Author: Marco Gaido Closes #21080 from mgaido91/SPARK-23986. --- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 +++-------- .../catalyst/expressions/CodeGenerationSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d97611c98ac91..f6b6775923ac6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -572,14 +572,9 @@ class CodegenContext { } else { s"${freshNamePrefix}_$name" } - if (freshNameIds.contains(fullName)) { - val id = freshNameIds(fullName) - freshNameIds(fullName) = id + 1 - s"$fullName$id" - } else { - freshNameIds += fullName -> 1 - fullName - } + val id = freshNameIds.getOrElse(fullName, 0) + freshNameIds(fullName) = id + 1 + s"${fullName}_$id" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f7c023111ff59..5b71becee2de0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -489,4 +489,14 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(!ctx.subExprEliminationExprs.contains(ref)) } } + + test("SPARK-23986: freshName can generate duplicated names") { + val ctx = new CodegenContext + val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: + ctx.freshName("myName11") :: Nil + assert(names1.distinct.length == 3) + val names2 = ctx.freshName("a") :: ctx.freshName("a") :: + ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil + assert(names2.distinct.length == 4) + } } From 1ca3c50fefb34532c78427fa74872db3ecbf7ba2 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 17 Apr 2018 10:11:08 -0700 Subject: [PATCH 327/593] [SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariate summarizer ## What changes were proposed in this pull request? Python API for DataFrame-based multivariate summarizer. ## How was this patch tested? doctest added. Author: WeichenXu Closes #20695 from WeichenXu123/py_summarizer. --- python/pyspark/ml/stat.py | 193 +++++++++++++++++++++++++++++++++++++- 1 file changed, 192 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 93d0f4fd9148f..a06ab31a7a56a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -19,7 +19,9 @@ from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.wrapper import _jvm +from pyspark.ml.wrapper import JavaWrapper, _jvm +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.functions import lit class ChiSquareTest(object): @@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params): _jvm().PythonUtils.toSeq(params))) +class Summarizer(object): + """ + .. note:: Experimental + + Tools for vectorized statistics on MLlib Vectors. + The methods in this package provide various statistics for Vectors contained inside DataFrames. + This class lets users pick the statistics they would like to extract for a given column. + + >>> from pyspark.ml.stat import Summarizer + >>> from pyspark.sql import Row + >>> from pyspark.ml.linalg import Vectors + >>> summarizer = Summarizer.metrics("mean", "count") + >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + ... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + +-----------------------------------+ + |aggregate_metrics(features, weight)| + +-----------------------------------+ + |[[1.0,1.0,1.0], 1] | + +-----------------------------------+ + + >>> df.select(summarizer.summary(df.features)).show(truncate=False) + +--------------------------------+ + |aggregate_metrics(features, 1.0)| + +--------------------------------+ + |[[1.0,1.5,2.0], 2] | + +--------------------------------+ + + >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.0,1.0] | + +--------------+ + + >>> df.select(Summarizer.mean(df.features)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.5,2.0] | + +--------------+ + + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def mean(col, weightCol=None): + """ + return a column of mean summary + """ + return Summarizer._get_single_metric(col, weightCol, "mean") + + @staticmethod + @since("2.4.0") + def variance(col, weightCol=None): + """ + return a column of variance summary + """ + return Summarizer._get_single_metric(col, weightCol, "variance") + + @staticmethod + @since("2.4.0") + def count(col, weightCol=None): + """ + return a column of count summary + """ + return Summarizer._get_single_metric(col, weightCol, "count") + + @staticmethod + @since("2.4.0") + def numNonZeros(col, weightCol=None): + """ + return a column of numNonZero summary + """ + return Summarizer._get_single_metric(col, weightCol, "numNonZeros") + + @staticmethod + @since("2.4.0") + def max(col, weightCol=None): + """ + return a column of max summary + """ + return Summarizer._get_single_metric(col, weightCol, "max") + + @staticmethod + @since("2.4.0") + def min(col, weightCol=None): + """ + return a column of min summary + """ + return Summarizer._get_single_metric(col, weightCol, "min") + + @staticmethod + @since("2.4.0") + def normL1(col, weightCol=None): + """ + return a column of normL1 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL1") + + @staticmethod + @since("2.4.0") + def normL2(col, weightCol=None): + """ + return a column of normL2 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL2") + + @staticmethod + def _check_param(featuresCol, weightCol): + if weightCol is None: + weightCol = lit(1.0) + if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column): + raise TypeError("featureCol and weightCol should be a Column") + return featuresCol, weightCol + + @staticmethod + def _get_single_metric(col, weightCol, metric): + col, weightCol = Summarizer._check_param(col, weightCol) + return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric, + col._jc, weightCol._jc)) + + @staticmethod + @since("2.4.0") + def metrics(*metrics): + """ + Given a list of metrics, provides a builder that it turns computes metrics from a column. + + See the documentation of [[Summarizer]] for an example. + + The following metrics are accepted (case sensitive): + - mean: a vector that contains the coefficient-wise mean. + - variance: a vector tha contains the coefficient-wise variance. + - count: the count of all vectors seen. + - numNonzeros: a vector with the number of non-zeros for each coefficients + - max: the maximum for each coefficient. + - min: the minimum for each coefficient. + - normL2: the Euclidian norm for each coefficient. + - normL1: the L1 norm of each coefficient (sum of the absolute values). + + :param metrics: + metrics that can be provided. + :return: + an object of :py:class:`pyspark.ml.stat.SummaryBuilder` + + Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + interface. + """ + sc = SparkContext._active_spark_context + js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics", + _to_seq(sc, metrics)) + return SummaryBuilder(js) + + +class SummaryBuilder(JavaWrapper): + """ + .. note:: Experimental + + A builder object that provides summary statistics about a given column. + + Users should not directly create such builders, but instead use one of the methods in + :py:class:`pyspark.ml.stat.Summarizer` + + .. versionadded:: 2.4.0 + + """ + def __init__(self, jSummaryBuilder): + super(SummaryBuilder, self).__init__(jSummaryBuilder) + + @since("2.4.0") + def summary(self, featuresCol, weightCol=None): + """ + Returns an aggregate object that contains the summary of the column with the requested + metrics. + + :param featuresCol: + a column that contains features Vector object. + :param weightCol: + a column that contains weight value. Default weight is 1.0. + :return: + an aggregate column that contains the statistics. The exact content of this + structure is determined during the creation of the builder. + """ + featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol) + return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 5fccdae18911793967b315c02c058eb737e46174 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Apr 2018 21:08:42 -0500 Subject: [PATCH 328/593] [SPARK-22968][DSTREAM] Throw an exception on partition revoking issue ## What changes were proposed in this pull request? Kafka partitions can be revoked when new consumers joined in the consumer group to rebalance the partitions. But current Spark Kafka connector code makes sure there's no partition revoking scenarios, so trying to get latest offset from revoked partitions will throw exceptions as JIRA mentioned. Partition revoking happens when new consumer joined the consumer group, which means different streaming apps are trying to use same group id. This is fundamentally not correct, different apps should use different consumer group. So instead of throwing an confused exception from Kafka, improve the exception message by identifying revoked partition and directly throw an meaningful exception when partition is revoked. Besides, this PR also fixes bugs in `DirectKafkaWordCount`, this example simply cannot be worked without the fix. ``` 8/01/05 09:48:27 INFO internals.ConsumerCoordinator: Revoking previously assigned partitions [kssh-7, kssh-4, kssh-3, kssh-6, kssh-5, kssh-0, kssh-2, kssh-1] for group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: (Re-)joining group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: Successfully joined group use_a_separate_group_id_for_each_stream with generation 4 18/01/05 09:48:27 INFO internals.ConsumerCoordinator: Setting newly assigned partitions [kssh-7, kssh-4, kssh-6, kssh-5] for group use_a_separate_group_id_for_each_stream ``` ## How was this patch tested? This is manually verified in local cluster, unfortunately I'm not sure how to simulate it in UT, so propose the PR without UT added. Author: jerryshao Closes #21038 from jerryshao/SPARK-22968. --- .../streaming/DirectKafkaWordCount.scala | 17 +++++++++++++---- .../kafka010/DirectKafkaInputDStream.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index def06026bde96..2082fb71afdf1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -18,6 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka010._ @@ -26,18 +29,20 @@ import org.apache.spark.streaming.kafka010._ * Consumes messages from one or more topics in Kafka and does wordcount. * Usage: DirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ object DirectKafkaWordCount { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { System.err.println(s""" |Usage: DirectKafkaWordCount | is a list of one or more Kafka brokers + | is a consumer group name to consume from topics | is a list of one or more kafka topics to consume from | """.stripMargin) @@ -46,7 +51,7 @@ object DirectKafkaWordCount { StreamingExamples.setStreamingLogLevels() - val Array(brokers, topics) = args + val Array(brokers, groupId, topics) = args // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") @@ -54,7 +59,11 @@ object DirectKafkaWordCount { // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet - val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val kafkaParams = Map[String, Object]( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers, + ConsumerConfig.GROUP_ID_CONFIG -> groupId, + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer], + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer]) val messages = KafkaUtils.createDirectStream[String, String]( ssc, LocationStrategies.PreferConsistent, diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 215b7cab703fb..c3221481556f5 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -190,8 +190,20 @@ private[spark] class DirectKafkaInputDStream[K, V]( // make sure new partitions are reflected in currentOffsets val newPartitions = parts.diff(currentOffsets.keySet) + + // Check if there's any partition been revoked because of consumer rebalance. + val revokedPartitions = currentOffsets.keySet.diff(parts) + if (revokedPartitions.nonEmpty) { + throw new IllegalStateException(s"Previously tracked partitions " + + s"${revokedPartitions.mkString("[", ",", "]")} been revoked by Kafka because of consumer " + + s"rebalance. This is mostly due to another stream with same group id joined, " + + s"please check if there're different streaming application misconfigure to use same " + + s"group id. Fundamentally different stream should use different group id") + } + // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap + // don't want to consume messages, so pause c.pause(newPartitions.asJava) // find latest available offsets From 1e3b8762a854a07c317f69fba7fa1a7bcdc58ff3 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Apr 2018 10:36:41 +0800 Subject: [PATCH 329/593] [SPARK-21479][SQL] Outer join filter pushdown in null supplying table when condition is on one of the joined columns ## What changes were proposed in this pull request? Added `TransitPredicateInOuterJoin` optimization rule that transits constraints from the preserved side of an outer join to the null-supplying side. The constraints of the join operator will remain unchanged. ## How was this patch tested? Added 3 tests in `InferFiltersFromConstraintsSuite`. Author: maryannxue Closes #20816 from maryannxue/spark-21479. --- .../sql/catalyst/optimizer/Optimizer.scala | 42 +++++++++++++++++-- .../plans/logical/QueryPlanConstraints.scala | 25 +++++++++-- .../InferFiltersFromConstraintsSuite.scala | 36 ++++++++++++++++ 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5fb59ef350b8b..913354e4df0e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,8 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and - * LeftSemi joins. + * In addition, for left/right outer joins, infer predicate from the preserved side of the Join + * operator and push the inferred filter over to the null-supplying side. For example, if the + * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in + * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will + * be applied to the null-supplying side. */ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { @@ -671,11 +674,42 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe val newConditionOpt = conditionOpt match { case Some(condition) => val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None + if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt case None => additionalConstraints.reduceOption(And) } - if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join + // Infer filter for left/right outer joins + val newLeftOpt = joinType match { + case RightOuter if newConditionOpt.isDefined => + val inferredConstraints = left.getRelevantConstraints( + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(left.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, left)) + case _ => None + } + val newRightOpt = joinType match { + case LeftOuter if newConditionOpt.isDefined => + val inferredConstraints = right.getRelevantConstraints( + right.constraints + .union(left.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(right.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, right)) + case _ => None + } + + if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) + || newLeftOpt.isDefined || newRightOpt.isDefined) { + Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) + } else { + join + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 046848875548b..a29f3d29236c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -41,9 +41,7 @@ trait QueryPlanConstraints { self: LogicalPlan => * example, if this set contains the expression `a = 2` then that expression is guaranteed to * evaluate to `true` for all rows produced. */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - }) + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints @@ -55,6 +53,23 @@ trait QueryPlanConstraints { self: LogicalPlan => */ protected def validConstraints: Set[Expression] = Set.empty + /** + * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as + * equality constraints and `isNotNull` constraints, etc., and that only contains references + * to this [[LogicalPlan]] node. + */ + def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { + val allRelevantConstraints = + if (conf.constraintPropagationEnabled) { + constraints + .union(inferAdditionalConstraints(constraints)) + .union(constructIsNotNullConstraints(constraints)) + } else { + constraints + } + ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + } + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this @@ -120,4 +135,8 @@ trait QueryPlanConstraints { self: LogicalPlan => destination: Attribute): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination }) + + private def selfReferenceOnly(e: Expression): Boolean = { + e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index f78c2356e35a5..e068f51044589 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze + val left = x.where(IsNotNull('a) && 'a === 2) + val right = y.where(IsNotNull('a) && 'a === 2) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze + val left = x.where(IsNotNull('a) && 'a > 5) + val right = y.where(IsNotNull('a) && 'a > 5) + val correctAnswer = left.join(right, RightOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join no filter push down to preserved side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze + val left = x + val right = y.where(IsNotNull('a) && 'a === 1) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 310a8cd06299e434d94a1e391a6eb62944112446 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Apr 2018 11:51:10 +0800 Subject: [PATCH 330/593] [SPARK-23341][SQL] define some standard options for data source v2 ## What changes were proposed in this pull request? Each data source implementation can define its own options and teach its users how to set them. Spark doesn't have any restrictions about what options a data source should or should not have. It's possible that some options are very common and many data sources use them. However different data sources may define the common options(key and meaning) differently, which is quite confusing to end users. This PR defines some standard options that data sources can optionally adopt: path, table and database. ## How was this patch tested? a new test case. Author: Wenchen Fan Closes #20535 from cloud-fan/options. --- .../sql/sources/v2/DataSourceOptions.java | 100 ++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 14 ++- .../sources/v2/DataSourceOptionsSuite.scala | 25 +++++ 3 files changed, 135 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index c32053580f016..83df3be747085 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -17,16 +17,61 @@ package org.apache.spark.sql.sources.v2; +import java.io.IOException; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; + +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.spark.annotation.InterfaceStability; /** * An immutable string-to-string map in which keys are case-insensitive. This is used to represent * data source options. + * + * Each data source implementation can define its own options and teach its users how to set them. + * Spark doesn't have any restrictions about what options a data source should or should not have. + * Instead Spark defines some standard options that data sources can optionally adopt. It's possible + * that some options are very common and many data sources use them. However different data + * sources may define the common options(key and meaning) differently, which is quite confusing to + * end users. + * + * The standard options defined by Spark: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
    Option keyOption value
    pathA path string of the data files/directories, like + * path1, /absolute/file2, path3/*. The path can + * either be relative or absolute, points to either file or directory, and can contain + * wildcards. This option is commonly used by file-based data sources.
    pathsA JSON array style paths string of the data files/directories, like + * ["path1", "/absolute/file2"]. The format of each path is same as the + * path option, plus it should follow JSON string literal format, e.g. quotes + * should be escaped, pa\"th means pa"th. + *
    tableA table name string representing the table name directly without any interpretation. + * For example, db.tbl means a table called db.tbl, not a table called tbl + * inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
    databaseA database name string representing the database name directly without any + * interpretation, which is very similar to the table name option.
    */ @InterfaceStability.Evolving public class DataSourceOptions { @@ -97,4 +142,59 @@ public double getDouble(String key, double defaultValue) { return keyLowerCasedMap.containsKey(lcaseKey) ? Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; } + + /** + * The option key for singular path. + */ + public static final String PATH_KEY = "path"; + + /** + * The option key for multiple paths. + */ + public static final String PATHS_KEY = "paths"; + + /** + * The option key for table name. + */ + public static final String TABLE_KEY = "table"; + + /** + * The option key for database name. + */ + public static final String DATABASE_KEY = "database"; + + /** + * Returns all the paths specified by both the singular path option and the multiple + * paths option. + */ + public String[] paths() { + String[] singularPath = + get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); + Optional pathsStr = get(PATHS_KEY); + if (pathsStr.isPresent()) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); + return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); + } catch (IOException e) { + return singularPath; + } + } else { + return singularPath; + } + } + + /** + * Returns the value of the table name option. + */ + public Optional tableName() { + return get(TABLE_KEY); + } + + /** + * Returns the value of the database name option. + */ + public Optional databaseName() { + return get(DATABASE_KEY); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ae3ba1690f696..d640fdc530ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,8 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD @@ -34,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -171,7 +173,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(path: String): DataFrame = { - option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` + // force invocation of `load(...varargs...)` + option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) } /** @@ -193,10 +196,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + } Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, extraOptions.toMap ++ sessionOptions, + ds, extraOptions.toMap ++ sessionOptions + pathsOption, userSpecifiedSchema = userSpecifiedSchema)) - } else { loadV1Source(paths: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 31dfc55b23361..cfa69a86de1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -79,4 +79,29 @@ class DataSourceOptionsSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } + + test("standard options") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.TABLE_KEY -> "tbl").asJava) + + assert(options.paths().toSeq == Seq("abc")) + assert(options.tableName().get() == "tbl") + assert(!options.databaseName().isPresent) + } + + test("standard options with both singular path and multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) + + assert(options.paths().toSeq == Seq("abc", "c", "d")) + } + + test("standard options with only multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) + + assert(options.paths().toSeq == Seq("c", "d\"e")) + } } From cce469435d61bda5893d9aa6cfdf7ea46fa717df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 17 Apr 2018 21:03:57 -0700 Subject: [PATCH 331/593] [SPARK-24002][SQL] Task not serializable caused by org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes ## What changes were proposed in this pull request? ``` Py4JJavaError: An error occurred while calling o153.sql. : org.apache.spark.SparkException: Job aborted. at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:223) at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:189) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:70) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:68) at org.apache.spark.sql.execution.command.ExecutedCommandExec.executeCollect(commands.scala:79) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$59.apply(Dataset.scala:3021) at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:89) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:127) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3020) at org.apache.spark.sql.Dataset.(Dataset.scala:190) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:74) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:646) at sun.reflect.GeneratedMethodAccessor153.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380) at py4j.Gateway.invoke(Gateway.java:293) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:226) at java.lang.Thread.run(Thread.java:748) Caused by: org.apache.spark.SparkException: Exception thrown in Future.get: at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:190) at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:267) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec.doConsume(BroadcastNestedLoopJoinExec.scala:530) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.ProjectExec.consume(basicPhysicalOperators.scala:37) at org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:69) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.FilterExec.consume(basicPhysicalOperators.scala:144) ... at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:190) ... 23 more Caused by: java.util.concurrent.ExecutionException: org.apache.spark.SparkException: Task not serializable at java.util.concurrent.FutureTask.report(FutureTask.java:122) at java.util.concurrent.FutureTask.get(FutureTask.java:206) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:179) ... 276 more Caused by: org.apache.spark.SparkException: Task not serializable at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:340) at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:330) at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:156) at org.apache.spark.SparkContext.clean(SparkContext.scala:2380) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:850) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:849) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:371) at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:849) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:417) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:89) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:125) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:271) at org.apache.spark.sql.execution.aggregate.HashAggregateExec.inputRDDs(HashAggregateExec.scala:181) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:414) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:61) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:70) at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:264) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:93) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:81) at org.apache.spark.sql.execution.SQLExecution$.withExecutionId(SQLExecution.scala:150) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:80) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:76) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) ... 1 more Caused by: java.nio.BufferUnderflowException at java.nio.HeapByteBuffer.get(HeapByteBuffer.java:151) at java.nio.ByteBuffer.get(ByteBuffer.java:715) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes(Binary.java:405) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytesUnsafe(Binary.java:414) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.writeObject(Binary.java:484) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1128) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496) ``` The Parquet filters are serializable but not thread safe. SparkPlan.prepare() could be called in different threads (BroadcastExchange will call it in a thread pool). Thus, we could serialize the same Parquet filter at the same time. This is not easily reproduced. The fix is to avoid serializing these Parquet filters in the driver. This PR is to avoid serializing these Parquet filters by moving the parquet filter generation from the driver to executors. ## How was this patch tested? Having two queries one is a 1000-line SQL query and a 3000-line SQL query. Need to run at least one hour with a heavy write workload to reproduce once. Author: gatorsmile Closes #21086 from gatorsmile/taskNotSerializable. --- .../parquet/ParquetFileFormat.scala | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 476bd02374364..d8f47eec952de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -321,19 +321,6 @@ class ParquetFileFormat SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - // Try to push down filters when filter push-down is enabled. - val pushed = - if (sparkSession.sessionState.conf.parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -351,12 +338,26 @@ class ParquetFileFormat val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = + sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) From f81fa478ff990146e2a8e463ac252271448d96f5 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 18 Apr 2018 18:41:55 +0900 Subject: [PATCH 332/593] [SPARK-23926][SQL] Extending reverse function to support ArrayType arguments ## What changes were proposed in this pull request? This PR extends `reverse` functions to be able to operate over array columns and covers: - Introduction of `Reverse` expression that represents logic for reversing arrays and also strings - Removal of `StringReverse` expression - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(1, 3, 4, 2), null ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = inputadapter_value.copy(); /* 051 */ for(int k = 0; k < project_length / 2; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ boolean isNullAtK = project_value.isNullAt(k); /* 054 */ boolean isNullAtL = project_value.isNullAt(l); /* 055 */ if(!isNullAtK) { /* 056 */ int el = project_value.getInt(k); /* 057 */ if(!isNullAtL) { /* 058 */ project_value.setInt(k, project_value.getInt(l)); /* 059 */ } else { /* 060 */ project_value.setNullAt(k); /* 061 */ } /* 062 */ project_value.setInt(l, el); /* 063 */ } else if (!isNullAtL) { /* 064 */ project_value.setInt(k, project_value.getInt(l)); /* 065 */ project_value.setNullAt(l); /* 066 */ } /* 067 */ } /* 068 */ /* 069 */ } ``` ### Non-primitive type ``` val df = Seq( Seq("a", "c", "d", "b"), null ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]); /* 051 */ for(int k = 0; k < project_length; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l)); /* 054 */ } /* 055 */ /* 056 */ } ``` Author: mn-mikke Closes #21034 from mn-mikke/feature/array-api-reverse-to-master. --- python/pyspark/sql/functions.py | 20 +++- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 88 +++++++++++++++++ .../expressions/stringExpressions.scala | 20 ---- .../CollectionExpressionsSuite.scala | 44 +++++++++ .../expressions/StringExpressionsSuite.scala | 6 +- .../org/apache/spark/sql/functions.scala | 15 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 94 +++++++++++++++++++ 8 files changed, 256 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6ca22b610843d..d3bb0a5d6b36a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1414,7 +1414,6 @@ def hash(*cols): 'uppercase. Words are delimited by whitespace.', 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', - 'reverse': 'Reverses the string column and returns it as a new string column.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', @@ -2128,6 +2127,25 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4dd1ca509bf2c..38c874ad948e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -336,7 +336,6 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), - expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), @@ -411,6 +410,7 @@ object FunctionRegistry { expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), + expression[Reverse]("reverse"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c87777eed47a..76b71f5b86074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } +/** + * Returns a reversed string or an array with reverse order of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + /** * Checks if the array (left) has the element (right) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 22fbb8998ed89..5a02ca0d6862c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression) } } -/** - * Returns the reversed given string. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns the reversed given string.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - """) -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - /** * Returns a string consisting of n spaces. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5a31e3a30edd6..517639dbc7232 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + + test("Reverse") { + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai7 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2)) + checkEvaluation(Reverse(ai1), Seq(3, 1, 2)) + checkEvaluation(Reverse(ai2), Seq(3, null, 1, null)) + checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2)) + checkEvaluation(Reverse(ai4), Seq(null, null, null)) + checkEvaluation(Reverse(ai5), Seq(1)) + checkEvaluation(Reverse(ai6), Seq.empty) + checkEvaluation(Reverse(ai7), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType)) + val as7 = Literal.create(null, ArrayType(StringType)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b")) + checkEvaluation(Reverse(as1), Seq("c", "a", "b")) + checkEvaluation(Reverse(as2), Seq("c", null, "a", null)) + checkEvaluation(Reverse(as3), Seq(null, "d", null, "b")) + checkEvaluation(Reverse(as4), Seq(null, null, null)) + checkEvaluation(Reverse(as5), Seq("a")) + checkEvaluation(Reverse(as6), Seq.empty) + checkEvaluation(Reverse(as7), null) + checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 9a1a4da074ce3..f1a6f9b8889fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("REVERSE") { val s = 'a.string.at(0) val row1 = create_row("abccc") - checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) - checkEvaluation(StringReverse(s), "cccba", row1) - checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) + checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) + checkEvaluation(Reverse(s), "cccba", row1) + checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 642ac056bb809..a55a800f48245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2464,14 +2464,6 @@ object functions { StringRepeat(str.expr, lit(n).expr) } - /** - * Reverses the string column and returns it as a new string column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } - /** * Trim the spaces from right end for the specified string value. * @@ -3316,6 +3308,13 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a reversed string or an array with reverse order of elements. + * @group collection_funcs + * @since 1.5.0 + */ + def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 636e86baedf6f..74c42f2599dca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("reverse function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + + // String test cases + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + + checkAnswer( + oneRowDF.select(reverse('s)), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(s)"), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.select(reverse('i)), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(i)"), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(null)"), + Seq(Row(null)) + ) + + // Array test cases (primitive-type elements) + val idf = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + + // Array test cases (non-primitive-type elements) + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(struct(1, 'a'))") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(map(1, 'a'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From f09a9e9418c1697d198de18f340b1288f5eb025c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 18 Apr 2018 08:22:05 -0700 Subject: [PATCH 333/593] [SPARK-24007][SQL] EqualNullSafe for FloatType and DoubleType might generate a wrong result by codegen. ## What changes were proposed in this pull request? `EqualNullSafe` for `FloatType` and `DoubleType` might generate a wrong result by codegen. ```scala scala> val df = Seq((Some(-1.0d), None), (None, Some(-1.0d))).toDF() df: org.apache.spark.sql.DataFrame = [_1: double, _2: double] scala> df.show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ scala> df.filter("_1 <=> _2").show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ ``` The result should be empty but the result remains two rows. ## How was this patch tested? Added a test. Author: Takuya UESHIN Closes #21094 from ueshin/issues/SPARK-24007/equalnullsafe. --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 ++++-- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6b6775923ac6..cf0a91ff00626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -582,8 +582,10 @@ class CodegenContext { */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" - case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" - case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" + case FloatType => + s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" + case DoubleType => + s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 8a8f8e10225fa..1bfd180ae4393 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -442,4 +442,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } + + test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") { + checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false) + checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) + } } From a9066478f6d98c3ae634c3bb9b09ee20bd60e111 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 00:05:47 +0200 Subject: [PATCH 334/593] [SPARK-23875][SQL][FOLLOWUP] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? Use specified accessor in `ArrayData.foreach` and `toArray`. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21099 from viirya/SPARK-23875-followup. --- .../org/apache/spark/sql/catalyst/util/ArrayData.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 2cf59d567c08c..104b428614849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -141,28 +141,29 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType).asInstanceOf[T] + values(i) = accessor(this, i).asInstanceOf[T] } i += 1 } values } - // todo: specialize this. def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { if (isNullAt(i)) { f(i, null) } else { - f(i, get(i, elementType)) + f(i, accessor(this, i)) } i += 1 } From 0c94e48bc50717e1627c0d2acd5382d9adc73c97 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 18 Apr 2018 16:37:41 -0700 Subject: [PATCH 335/593] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `wait` and `CountDownLatch` for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #20888 from gaborgsomogyi/SPARK-23775. --- .../spark/sql/DataFrameRangeSuite.scala | 78 +++++++++++-------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..a0fd74088ce8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql +import java.util.concurrent.{CountDownLatch, TimeUnit} + import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -152,39 +154,53 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) + // Save and restore the value because SparkContext is shared + val savedInterruptOnCancel = sparkContext + .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + + try { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") + + for (codegen <- Seq(true, false)) { + // This countdown latch used to make sure with all the stages cancelStage called in listener + val latch = new CountDownLatch(2) + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) + latch.countDown() + } } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) - } - } - sparkContext.addSparkListener(listener) - for (codegen <- Seq(true, false)) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 - val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + sparkContext.addSparkListener(listener) + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + val ex = intercept[SparkException] { + sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } + x + }.toDF("id").agg(sum("id")).collect() + } + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + } } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + latch.await(20, TimeUnit.SECONDS) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sparkContext.removeSparkListener(listener) } - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } + } finally { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, + savedInterruptOnCancel) } - sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -204,7 +220,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From 8bb0df2c65355dfdcd28e362ff661c6c7ebc99c0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 10:00:57 +0800 Subject: [PATCH 336/593] [SPARK-24014][PYSPARK] Add onStreamingStarted method to StreamingListener ## What changes were proposed in this pull request? The `StreamingListener` in PySpark side seems to be lack of `onStreamingStarted` method. This patch adds it and a test for it. This patch also includes a trivial doc improvement for `createDirectStream`. Original PR is #21057. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21098 from viirya/SPARK-24014. --- python/pyspark/streaming/kafka.py | 3 ++- python/pyspark/streaming/listener.py | 6 ++++++ python/pyspark/streaming/tests.py | 7 +++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index fdb9308604489..ed2e0e7d10fa2 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :param topics: list of topic_name to consume. :param kafkaParams: Additional params for Kafka. :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting - point of the stream. + point of the stream (a dictionary mapping `TopicAndPartition` to + integers). :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py index b830797f5c0a0..d4ecc215aea99 100644 --- a/python/pyspark/streaming/listener.py +++ b/python/pyspark/streaming/listener.py @@ -23,6 +23,12 @@ class StreamingListener(object): def __init__(self): pass + def onStreamingStarted(self, streamingStarted): + """ + Called when the streaming has been started. + """ + pass + def onReceiverStarted(self, receiverStarted): """ Called when a receiver has been started diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7dde7c0928c08..103940923dd4d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -507,6 +507,10 @@ def __init__(self): self.batchInfosCompleted = [] self.batchInfosStarted = [] self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) def onBatchSubmitted(self, batchSubmitted): self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) @@ -530,9 +534,12 @@ def func(dstream): batchInfosSubmitted = batch_collector.batchInfosSubmitted batchInfosStarted = batch_collector.batchInfosStarted batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime self.wait_for(batchInfosCompleted, 4) + self.assertEqual(len(streamingStartedTime), 1) + self.assertGreaterEqual(len(batchInfosSubmitted), 4) for info in batchInfosSubmitted: self.assertGreaterEqual(info.batchTime().milliseconds(), 0) From d5bec48b9cb225c19b43935c07b24090c51cacce Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 11:59:17 +0900 Subject: [PATCH 337/593] [SPARK-23919][SQL] Add array_position function ## What changes were proposed in this pull request? The PR adds the SQL function `array_position`. The behavior of the function is based on Presto's one. The function returns the position of the first occurrence of the element in array x (or 0 if not found) using 1-based index as BigInt. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21037 from kiszk/SPARK-23919. --- python/pyspark/sql/functions.py | 17 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 56 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 22 ++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 34 +++++++++++ 6 files changed, 144 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d3bb0a5d6b36a..36dcabc6766d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,6 +1845,23 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def array_position(col, value): + """ + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. + + .. note:: The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38c874ad948e1..74095fe697b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 76b71f5b86074..e6a05f535cb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + + +/** + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null + * + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 517639dbc7232..916cd3bb4cca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Reverse(as7), null) checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) } + + test("Array Position") { + val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPosition(a0, Literal(3)), 4L) + checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) + checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) + checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayPosition(a1, Literal("")), 2L) + checkEvaluation(ArrayPosition(a1, Literal("a")), 0L) + checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L) + checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayPosition(a3, Literal("")), null) + checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a55a800f48245..3a09ec4f1982e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3038,6 +3038,20 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Locates the position of the first occurrence of the value in the given array as long. + * Returns null if either of the arguments are null. + * + * @note The position is not zero based, but 1 based index. Returns 0 if value + * could not be found in array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_position(column: Column, value: Any): Column = withExpr { + ArrayPosition(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 74c42f2599dca..13161e7e24cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array position function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + checkAnswer( + df.select(array_position(df("a"), 1)), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.selectExpr("array_position(a, 1)"), + Seq(Row(1L), Row(0L)) + ) + + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("array_position(a, null)"), + Seq(Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L), Row(1L)) + ) + checkAnswer( + df.selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L), Row(1L)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 46bb2b5129833cc5829089bf1174a76cb7b81741 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 21:00:10 +0900 Subject: [PATCH 338/593] [SPARK-23924][SQL] Add element_at function ## What changes were proposed in this pull request? The PR adds the SQL function `element_at`. The behavior of the function is based on Presto's one. This function returns element of array at given index in value if column is array, or returns value for the given key in value if column is map. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21053 from kiszk/SPARK-23924. --- python/pyspark/sql/functions.py | 24 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 104 ++++++++++++++++++ .../expressions/complexTypeExtractors.scala | 64 +++++++---- .../CollectionExpressionsSuite.scala | 48 ++++++++ .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 48 ++++++++ 7 files changed, 276 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 36dcabc6766d8..1be68f2a4a448 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,30 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def element_at(col, extraction): + """ + Collection function: Returns element of array at given index in extraction if col is array. + Returns value for the given key in extraction if col is map. + + :param col: name of column containing array or map + :param extraction: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 74095fe697b6a..a44f2d5272b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -405,6 +405,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6a05f535cb1c..dba426e999dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression) }) } } + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6cdad19168dce..3fba52d745453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression) }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. + override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 916cd3bb4cca5..7d8fe211858b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a09ec4f1982e..9c8580378303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3052,6 +3052,17 @@ object functions { ArrayPosition(column.expr, Literal(value)) } + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. + * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 13161e7e24cfe..7c976c1b7f915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("element_at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 1b08c4393cf48e21fea9914d130d8d3bf544061d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:38:26 +0200 Subject: [PATCH 339/593] [SPARK-23584][SQL] NewInstance should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `NewInstance`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20778 from maropu/SPARK-23584. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../expressions/objects/objects.scala | 28 +++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 36 +++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e4274aaa9727e..818cc2fb1e8a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Constructor + +import org.apache.commons.lang3.reflect.ConstructorUtils + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned. Otherwise, it returns `None`. + */ + def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + } + /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 72b202b3a5020..1645bd7d57b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -449,8 +449,32 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + sys.error(s"Couldn't find a valid constructor on $cls") + } + } + outerPointer.map { p => + val outerObj = p() + val d = outerObj.getClass +: paramTypes + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(outerObj +: args: _*) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(args: _*) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b0188b0098def..bf805f4f29ac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass { override def binOp(e1: Int, e2: Double): Double = e1 - e2 } +// Tests for NewInstance +class Outer extends Serializable { + class Inner(val value: Int) { + override def hashCode(): Int = super.hashCode() + override def equals(other: Any): Boolean = { + if (other.isInstanceOf[Inner]) { + value == other.asInstanceOf[Inner].value + } else { + false + } + } + } +} + class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-16622: The returned value of the called method in Invoke can be null") { @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23584 NewInstance should support interpreted execution") { + // Normal case test + val newInst1 = NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None) + checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3))) + + // Inner class case test + val outerObj = new Outer() + val newInst2 = NewInstance( + cls = classOf[outerObj.Inner], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[outerObj.Inner]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + } + test("LambdaVariable should support interpreted execution") { def genSchema(dt: DataType): Seq[StructType] = { Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), @@ -421,6 +456,7 @@ class TestBean extends Serializable { private var x: Int = 0 def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = assert(i != null, "this setter should not be called with null.") } From e13416502f814b04d59bb650953a0114332d163a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:42:50 +0200 Subject: [PATCH 340/593] [SPARK-23588][SQL] CatalystToExternalMap should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `CatalystToExternalMap`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20979 from maropu/SPARK-23588. --- .../expressions/objects/objects.scala | 39 +++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 34 +++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1645bd7d57b1d..bc17d1229420a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,12 +28,12 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1033,8 +1033,39 @@ case class CatalystToExternalMap private( override def children: Seq[Expression] = keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] + + private lazy val keyConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.keyType) + private lazy val valueConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + val clazz = Utils.classForName(collClass.getCanonicalName + "$") + val module = clazz.getField("MODULE$").get(null) + val method = clazz.getMethod("newBuilder") + method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input).asInstanceOf[MapData] + if (result != null) { + val builder = newMapBuilder() + builder.sizeHint(result.numElements()) + val keyArray = result.keyArray() + val valueArray = result.valueArray() + var i = 0 + while (i < result.numElements()) { + val key = keyConverter(keyArray.get(i, inputMapType.keyType)) + val value = valueConverter(valueArray.get(i, inputMapType.valueType)) + builder += Tuple2(key, value) + i += 1 + } + builder.result() + } else { + null + } + } override def dataType: DataType = ObjectType(collClass) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bf805f4f29ac5..bcd035c1eba0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -27,12 +27,14 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -162,9 +164,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), (DateTimeUtils.getClass, ObjectType(classOf[Date]), - "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777, + DateTimeUtils.toJavaDate(77777)), (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), - "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]), 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) ).foreach { case (cls, dataType, methodName, argType, arg, expected) => checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, @@ -450,6 +453,25 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]() + + test("SPARK-23588 CatalystToExternalMap should support interpreted execution") { + // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan + // with dummy input, resolve the plan by the analyzer, and replace the dummy input + // with a literal for tests. + val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer) + val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType))) + val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan) + + val analyzedPlan = SimpleAnalyzer.execute(plan) + val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head + + // Replaces the dummy input with a literal for tests here + val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val deserializer = toMapExpr.copy(inputData = Literal.create(data)) + checkObjectExprEvaluation(deserializer, expected = data) + } } class TestBean extends Serializable { From 9e10f69df52abde2de5d93435bab54e97dd59d9c Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 19 Apr 2018 21:07:21 +0800 Subject: [PATCH 341/593] [SPARK-22676][FOLLOW-UP] fix code style for test. ## What changes were proposed in this pull request? This pr address comments in https://github.com/apache/spark/pull/19868 ; Fix the code style for `org.apache.spark.sql.hive.QueryPartitionSuite` by using: `withTempView`, `withTempDir`, `withTable`... Author: jinxing Closes #21091 from jinxing64/SPARK-22676-FOLLOW-UP. --- .../spark/sql/hive/QueryPartitionSuite.scala | 109 +++++++----------- 1 file changed, 41 insertions(+), 68 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 78156b17fb43b..1e396553c9c52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -33,80 +33,53 @@ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import spark.implicits._ - test("SPARK-5068: query data when path doesn't exist") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + private def queryWhenPathNotExist(): Unit = { + withTempView("testData") { + withTable("table_with_partition", "createAndInsertTest") { + withTempDir { tmpDir => + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData).union(testData)) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData)) + } + } + } + } - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") { + queryWhenPathNotExist() } } test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") { sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + queryWhenPathNotExist() } } From d96c3e33cc2a95de8e15e1a2ddf50a8d0cc66dd2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 19 Apr 2018 21:21:22 +0800 Subject: [PATCH 342/593] [SPARK-21811][SQL] Fix the inconsistency behavior when finding the widest common type ## What changes were proposed in this pull request? Currently we find the wider common type by comparing the two types from left to right, this can be a problem when you have two data types which don't have a common type but each can be promoted to StringType. For instance, if you have a table with the schema: [c1: date, c2: string, c3: int] The following succeeds: SELECT coalesce(c1, c2, c3) FROM table While the following produces an exception: SELECT coalesce(c1, c3, c2) FROM table This is only a issue when the seq of dataTypes contains `StringType` and all the types can do string promotion. close #19033 ## How was this patch tested? Add test in `TypeCoercionSuite` Author: Xingbo Jiang Closes #21074 from jiangxb1987/typeCoercion. --- docs/sql-programming-guide.md | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 24 +++++++++++++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 13 ++++++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 55d35b9dd31db..e8ff1470970f7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1810,7 +1810,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ec7e7761dc4c2..281f206e8d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -175,11 +175,27 @@ object TypeCoercion { }) } + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt match { + case StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) - case None => None - }) + // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal + // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. + // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, + // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 8ac49dc05e3cf..fd6a3121663ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -539,6 +539,9 @@ class TypeCoercionSuite extends AnalysisTest { val floatLit = Literal.create(1.0f, FloatType) val timestampLit = Literal.create("2017-04-12", TimestampType) val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis()))) + val strArrayLit = Literal(Array("c")) + val intArrayLit = Literal(Array(1)) ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), @@ -572,6 +575,16 @@ class TypeCoercionSuite extends AnalysisTest { Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, intLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), + Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), + Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), + Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) } test("CreateArray casts") { From 0deaa5251326a32a3d2d2b8851193ca926303972 Mon Sep 17 00:00:00 2001 From: wuyi Date: Thu, 19 Apr 2018 09:00:33 -0500 Subject: [PATCH 343/593] [SPARK-24021][CORE] fix bug in BlacklistTracker's updateBlacklistForFetchFailure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? There‘s a miswrite in BlacklistTracker's updateBlacklistForFetchFailure: ``` val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) blacklistedExecsOnNode += exec ``` where first **exec** should be **host**. ## How was this patch tested? adjust existed test. Author: wuyi Closes #21104 from Ngone51/SPARK-24021. --- .../scala/org/apache/spark/scheduler/BlacklistTracker.scala | 2 +- .../org/apache/spark/scheduler/BlacklistTrackerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 952598f6de19d..30cf75d43ee09 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -210,7 +210,7 @@ private[scheduler] class BlacklistTracker ( updateNextExpiryTime() killBlacklistedExecutor(exec) - val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]()) blacklistedExecsOnNode += exec } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 06d7afaaff55c..96c8404327e24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -574,6 +574,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. conf.set(config.BLACKLIST_KILL_ENABLED, true) blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) @@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) // Enable external shuffle service to see if all the executors on this node will be killed. conf.set(config.SHUFFLE_SERVICE_ENABLED, true) From 6e19f7683fc73fabe7cdaac4eb1982d2e3e607b7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Apr 2018 17:54:53 +0200 Subject: [PATCH 344/593] [SPARK-23989][SQL] exchange should copy data before non-serialized shuffle ## What changes were proposed in this pull request? In Spark SQL, we usually reuse the `UnsafeRow` instance and need to copy the data when a place buffers non-serialized objects. Shuffle may buffer objects if we don't make it to the bypass merge shuffle or unsafe shuffle. `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle` misses the case that, if `spark.sql.shuffle.partitions` is large enough, we could fail to run unsafe shuffle and go with the non-serialized shuffle. This bug is very hard to hit since users wouldn't set such a large number of partitions(16 million) for Spark SQL exchange. TODO: test ## How was this patch tested? todo. Author: Wenchen Fan Closes #21101 from cloud-fan/shuffle. --- .../exchange/ShuffleExchangeExec.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 4d95ee34f30de..b89203719541b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -153,12 +153,9 @@ object ShuffleExchangeExec { * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. * * @param partitioner the partitioner for the shuffle - * @param serializer the serializer that will be used to write rows * @return true if rows should be copied before being shuffled, false otherwise */ - private def needToCopyObjectsBeforeShuffle( - partitioner: Partitioner, - serializer: Serializer): Boolean = { + private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = { // Note: even though we only use the partitioner's `numPartitions` field, we require it to be // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output @@ -167,22 +164,24 @@ object ShuffleExchangeExec { val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + val numParts = partitioner.numPartitions if (sortBasedShuffleOn) { - val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + if (numParts <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializer.supportsRelocationOfSerializedObjects) { + } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records // prior to sorting them. This optimization is only applied in cases where shuffle // dependency does not specify an aggregator or ordering and the record serializer has - // certain properties. If this optimization is enabled, we can safely avoid the copy. + // certain properties and the number of partitions doesn't exceed the limitation. If this + // optimization is enabled, we can safely avoid the copy. // - // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only - // need to check whether the optimization is enabled and supported by our serializer. + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the + // serializer in Spark SQL always satisfy the properties, so we only need to check whether + // the number of partitions exceeds the limitation. false } else { // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must @@ -298,7 +297,7 @@ object ShuffleExchangeExec { rdd } - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + if (needToCopyObjectsBeforeShuffle(part)) { newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } From a471880afbeafd4ef54c15a97e72ea7ff784a88d Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Thu, 19 Apr 2018 09:40:20 -0700 Subject: [PATCH 345/593] [SPARK-24026][ML] Add Power Iteration Clustering to spark.ml ## What changes were proposed in this pull request? This PR adds PowerIterationClustering as a Transformer to spark.ml. In the transform method, it calls spark.mllib's PowerIterationClustering.run() method and transforms the return value assignments (the Kmeans output of the pseudo-eigenvector) as a DataFrame (id: LongType, cluster: IntegerType). This PR is copied and modified from https://github.com/apache/spark/pull/15770 The primary author is wangmiao1981 ## How was this patch tested? This PR has 2 types of tests: * Copies of tests from spark.mllib's PIC tests * New tests specific to the spark.ml APIs Author: wm624@hotmail.com Author: wangmiao1981 Author: Joseph K. Bradley Closes #21090 from jkbradley/wangmiao1981-pic. --- .../clustering/PowerIterationClustering.scala | 256 ++++++++++++++++++ .../PowerIterationClusteringSuite.scala | 238 ++++++++++++++++ 2 files changed, 494 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala new file mode 100644 index 0000000000000..2c30a1d9aa947 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +/** + * Common params for PowerIterationClustering + */ +private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter + with HasPredictionCol { + + /** + * The number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.4.0") + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) + + /** @group getParam */ + @Since("2.4.0") + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices. + * Default: random. + * @group expertParam + */ + @Since("2.4.0") + final val initMode = { + val allowedParams = ParamValidators.inArray(Array("random", "degree")) + new Param[String](this, "initMode", "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " + + "of similarities with other vertices. Supported options: 'random' and 'degree'.", + allowedParams) + } + + /** @group expertGetParam */ + @Since("2.4.0") + def getInitMode: String = $(initMode) + + /** + * Param for the name of the input column for vertex IDs. + * Default: "id" + * @group param + */ + @Since("2.4.0") + val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + (value: String) => value.nonEmpty) + + setDefault(idCol, "id") + + /** @group getParam */ + @Since("2.4.0") + def getIdCol: String = getOrDefault(idCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "neighbors" + * @group param + */ + @Since("2.4.0") + val neighborsCol = new Param[String](this, "neighborsCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(neighborsCol, "neighbors") + + /** @group getParam */ + @Since("2.4.0") + def getNeighborsCol: String = $(neighborsCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "similarities" + * @group param + */ + @Since("2.4.0") + val similaritiesCol = new Param[String](this, "similaritiesCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(similaritiesCol, "similarities") + + /** @group getParam */ + @Since("2.4.0") + def getSimilaritiesCol: String = $(similaritiesCol) + + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(schema, $(neighborsCol), + Seq(ArrayType(IntegerType, containsNull = false), + ArrayType(LongType, containsNull = false))) + SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), + Seq(ArrayType(FloatType, containsNull = false), + ArrayType(DoubleType, containsNull = false))) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * Lin and Cohen. From the abstract: + * PIC finds a very low-dimensional embedding of a dataset using truncated power + * iteration on a normalized pair-wise similarity matrix of the data. + * + * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix + * is a symmetric matrix whose entries are non-negative similarities between items. + * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: + * - `idCol`: vertex ID + * - `neighborsCol`: neighbors of vertex in `idCol` + * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex + * in `idCol` and each neighbor in `neighborsCol` + * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` + * containing the cluster assignment in `[0,k)` for each row (vertex). + * + * Notes: + * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. + * Transform runs the iterative PIC algorithm to cluster the whole input dataset. + * - Input validation: This validates that similarities are non-negative but does NOT validate + * that the input matrix is symmetric. + * + * @see + * Spectral clustering (Wikipedia) + */ +@Since("2.4.0") +@Experimental +class PowerIterationClustering private[clustering] ( + @Since("2.4.0") override val uid: String) + extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> "random") + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("PowerIterationClustering")) + + /** @group setParam */ + @Since("2.4.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + @Since("2.4.0") + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group setParam */ + @Since("2.4.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.4.0") + def setIdCol(value: String): this.type = set(idCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + + @Since("2.4.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + val sparkSession = dataset.sparkSession + val idColValue = $(idCol) + val rdd: RDD[(Long, Long, Double)] = + dataset.select( + col($(idCol)).cast(LongType), + col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), + col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) + ).rdd.flatMap { + case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => + require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + + s"equal to the the length of the neighbor similarity list. Row for ID " + + s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + + s"of length ${sims.length}.") + nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { + case (nbr, similarity) => (id, nbr, similarity) + } + } + val algorithm = new MLlibPowerIterationClustering() + .setK($(k)) + .setInitializationMode($(initMode)) + .setMaxIterations($(maxIter)) + val model = algorithm.run(rdd) + + val predictionsRDD: RDD[Row] = model.assignments.map { assignment => + Row(assignment.id, assignment.cluster) + } + + val predictionsSchema = StructType(Seq( + StructField($(idCol), LongType, nullable = false), + StructField($(predictionCol), IntegerType, nullable = false))) + val predictions = { + val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) + dataset.schema($(idCol)).dataType match { + case _: LongType => + uncastPredictions + case otherType => + uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) + } + } + + dataset.join(predictions, $(idCol)) + } + + @Since("2.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra) +} + +@Since("2.4.0") +object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] { + + @Since("2.4.0") + override def load(path: String): PowerIterationClustering = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..65328df17baff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import scala.collection.mutable + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + + +class PowerIterationClusteringSuite extends SparkFunSuite + with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var data: Dataset[_] = _ + final val r1 = 1.0 + final val n1 = 10 + final val r2 = 4.0 + final val n2 = 40 + + override def beforeAll(): Unit = { + super.beforeAll() + + data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2) + } + + test("default parameters") { + val pic = new PowerIterationClustering() + + assert(pic.getK === 2) + assert(pic.getMaxIter === 20) + assert(pic.getInitMode === "random") + assert(pic.getPredictionCol === "prediction") + assert(pic.getIdCol === "id") + assert(pic.getNeighborsCol === "neighbors") + assert(pic.getSimilaritiesCol === "similarities") + } + + test("parameter validation") { + intercept[IllegalArgumentException] { + new PowerIterationClustering().setK(1) + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setIdCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setNeighborsCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setSimilaritiesCol("") + } + } + + test("power iteration clustering") { + val n = n1 + n2 + + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + val result = model.transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + + val result2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .transform(data) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + result2.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + } + assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + } + + test("supported input types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + + def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + val typedData = data.select( + col("id").cast(idType).alias("id"), + col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), + col("similarities").cast(ArrayType(similarityType, containsNull = false)) + .alias("similarities") + ) + model.transform(typedData).collect() + } + + for (idType <- Seq(IntegerType, LongType)) { + runTest(idType, LongType, DoubleType) + } + for (neighborType <- Seq(IntegerType, LongType)) { + runTest(LongType, neighborType, DoubleType) + } + for (similarityType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, similarityType) + } + } + + test("invalid input: wrong types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id").cast(DoubleType).alias("id"), + col("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors"), + col("neighbors").alias("similarities") + ) + model.transform(typedData) + } + } + + test("invalid input: negative similarity") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(-1.0)), + (1, Array(0), Array(-1.0)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("Similarity must be nonnegative")) + } + + test("invalid input: mismatched lengths for neighbor and similarity arrays") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(0.5)), + (1, Array(0, 2), Array(0.5)), + (2, Array(1), Array(0.5)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + + "the neighbor similarity list.")) + assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + } + + test("read/write") { + val t = new PowerIterationClustering() + .setK(4) + .setMaxIter(100) + .setInitMode("degree") + .setIdCol("test_id") + .setNeighborsCol("myNeighborsCol") + .setSimilaritiesCol("mySimilaritiesCol") + .setPredictionCol("test_prediction") + testDefaultReadWrite(t) + } +} + +object PowerIterationClusteringSuite { + + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + + def generatePICData( + spark: SparkSession, + r1: Double, + r2: Double, + n1: Int, + n2: Int): DataFrame = { + // Generate two circles following the example in the PIC paper. + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + + val rows = for (i <- 1 until n) yield { + val neighbors = for (j <- 0 until i) yield { + j.toLong + } + val similarities = for (j <- 0 until i) yield { + sim(points(i), points(j)) + } + (i.toLong, neighbors.toArray, similarities.toArray) + } + + spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + } + +} From 9ea8d3d31b75246bf61118ac7934bc92c18b5f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 19 Apr 2018 18:55:59 +0200 Subject: [PATCH 346/593] [SPARK-22362][SQL] Add unit test for Window Aggregate Functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Improving the test coverage of window functions focusing on missing test for window aggregate functions. No new UDAF test is added as it has been tested already. ## How was this patch tested? Only new tests were added, automated tests were executed. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20046 from attilapiros/SPARK-22362. --- .../resources/sql-tests/inputs/window.sql | 10 +- .../sql-tests/results/window.sql.out | 30 +- .../sql/DataFrameWindowFunctionsSuite.scala | 266 ++++++++++++++++++ 3 files changed, 294 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c4bea34ec4cf3..cda4db4b449fe 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 133458ae9303b..4afbcd62853dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val -- !query 17 schema -struct +struct,collect_set:array,skewness:double,kurtosis:double> -- !query 17 output -NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 -2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 -3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 -- !query 18 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 281147835abde..3ea398aad7375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import scala.collection.mutable + import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ @@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } + test("corr, covar_pop, stddev_pop functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. + // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } + + test("collect_list in ascending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_list in descending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_set in window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", "3"), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20")).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "3")), + Row("b", Array("1", "2", "3")), + Row("c", Array("1", "2", "3")), + Row("d", Array("1", "2", "3")), + Row("e", Array("1", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")))) + } + + test("skewness and kurtosis functions in window") { + val df = Seq( + ("a", "p1", 1.0), + ("b", "p1", 1.0), + ("c", "p1", 2.0), + ("d", "p1", 2.0), + ("e", "p1", 3.0), + ("f", "p1", 3.0), + ("g", "p1", 3.0), + ("h", "p2", 1.0), + ("i", "p2", 2.0), + ("j", "p2", 5.0)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + skewness("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() + Seq( + Row("a", -0.27238010581457267, -1.506920415224914), + Row("b", -0.27238010581457267, -1.506920415224914), + Row("c", -0.27238010581457267, -1.506920415224914), + Row("d", -0.27238010581457267, -1.506920415224914), + Row("e", -0.27238010581457267, -1.506920415224914), + Row("f", -0.27238010581457267, -1.506920415224914), + Row("g", -0.27238010581457267, -1.506920415224914), + Row("h", 0.5280049792181881, -1.5000000000000013), + Row("i", 0.5280049792181881, -1.5000000000000013), + Row("j", 0.5280049792181881, -1.5000000000000013))) + } + + test("aggregation function on invalid column") { + val df = Seq((1, "1")).toDF("key", "value") + val e = intercept[AnalysisException]( + df.select($"key", count("invalid").over())) + assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]")) + } + + test("numerical aggregate functions on string column") { + val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") + checkAnswer( + df.select($"key", + var_pop("value1").over(), + variance("value1").over(), + stddev_pop("value1").over(), + stddev("value1").over(), + sum("value1").over(), + mean("value1").over(), + avg("value1").over(), + corr("value1", "value2").over(), + covar_pop("value1", "value2").over(), + covar_samp("value1", "value2").over(), + skewness("value1").over(), + kurtosis("value1").over()), + Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) + } + test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") @@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("b", 2, null, null, null, null, null, null))) } + test("last/first on descending ordered window") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, "v"), + ("b", 1, "k"), + ("b", 2, "l"), + ("b", 3, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order".desc) + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, "v", "v", "v", null, null, "x"), + Row("a", 1, "v", "v", "v", "x", "x", "x"), + Row("a", 2, "v", "v", "v", "y", "y", "y"), + Row("a", 3, "v", "v", "v", "z", "z", "z"), + Row("a", 4, "v", "v", "v", "v", "v", "v"), + Row("b", 1, null, null, "l", "k", "k", "k"), + Row("b", 2, null, null, "l", "l", "l", "l"), + Row("b", 3, null, null, null, null, null, null))) + } + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) From e55953b0bf2a80b34127ba123417ee54955a6064 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 19 Apr 2018 15:06:27 -0700 Subject: [PATCH 347/593] [SPARK-24022][TEST] Make SparkContextSuite not flaky ## What changes were proposed in this pull request? SparkContextSuite.test("Cancelling stages/jobs with custom reasons.") could stay in an infinite loop because of the problem found and fixed in [SPARK-23775](https://issues.apache.org/jira/browse/SPARK-23775). This PR solves this mentioned flakyness by removing shared variable usages when cancel happens in a loop and using wait and CountDownLatch for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #21105 from gaborgsomogyi/SPARK-24022. --- .../org/apache/spark/SparkContextSuite.scala | 61 ++++++++----------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b30bd74812b36..ce9f2be1c02dd 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets -import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} import scala.concurrent.duration._ @@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Cancelling stages/jobs with custom reasons.") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") val REASON = "You shall not pass" - val slices = 10 - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - if (SparkContextSuite.cancelStage) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + for (cancelWhat <- Seq("stage", "job")) { + // This countdown latch used to make sure stage or job canceled in listener + val latch = new CountDownLatch(1) + + val listener = cancelWhat match { + case "stage" => + new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sc.cancelStage(taskStart.stageId, REASON) + latch.countDown() + } } - sc.cancelStage(taskStart.stageId, REASON) - SparkContextSuite.cancelStage = false - SparkContextSuite.semaphore.release(slices) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - if (SparkContextSuite.cancelJob) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + case "job" => + new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sc.cancelJob(jobStart.jobId, REASON) + latch.countDown() + } } - sc.cancelJob(jobStart.jobId, REASON) - SparkContextSuite.cancelJob = false - SparkContextSuite.semaphore.release(slices) - } } - } - sc.addSparkListener(listener) - - for (cancelWhat <- Seq("stage", "job")) { - SparkContextSuite.semaphore.drainPermits() - SparkContextSuite.isTaskStarted = false - SparkContextSuite.cancelStage = (cancelWhat == "stage") - SparkContextSuite.cancelJob = (cancelWhat == "job") + sc.addSparkListener(listener) val ex = intercept[SparkException] { - sc.range(0, 10000L, numSlices = slices).mapPartitions { x => - SparkContextSuite.isTaskStarted = true - // Block waiting for the listener to cancel the stage or job. - SparkContextSuite.semaphore.acquire() + sc.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } x }.count() } @@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } + latch.await(20, TimeUnit.SECONDS) eventually(timeout(20.seconds)) { assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sc.removeSparkListener(listener) } } @@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } object SparkContextSuite { - @volatile var cancelJob = false - @volatile var cancelStage = false @volatile var isTaskStarted = false @volatile var taskKilled = false @volatile var taskSucceeded = false From b3fde5a41ee625141b9d21ce32ea68c082449430 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 20 Apr 2018 12:06:41 +0800 Subject: [PATCH 348/593] [SPARK-23877][SQL] Use filter predicates to prune partitions in metadata-only queries ## What changes were proposed in this pull request? This updates the OptimizeMetadataOnlyQuery rule to use filter expressions when listing partitions, if there are filter nodes in the logical plan. This avoids listing all partitions for large tables on the driver. This also fixes a minor bug where the partitions returned from fsRelation cannot be serialized without hitting a stack level too deep error. This is caused by serializing a stream to executors, where the stream is a recursive structure. If the stream is too long, the serialization stack reaches the maximum level of depth. The fix is to create a LocalRelation using an Array instead of the incoming Seq. ## How was this patch tested? Existing tests for metadata-only queries. Author: Ryan Blue Closes #20988 from rdblue/SPARK-23877-metadata-only-push-filters. --- .../execution/OptimizeMetadataOnlyQuery.scala | 94 +++++++++++++------ .../OptimizeHiveMetadataOnlyQuerySuite.scala | 68 ++++++++++++++ 2 files changed, 132 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index dc4aff9f12580..acbd4becb8549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -49,9 +49,9 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) => + case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(partAttrs)) { + if (a.references.subsetOf(attrs)) { val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -67,7 +67,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic }) } if (isAllDistinctAgg) { - a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation))) + a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters))) } else { a } @@ -98,14 +98,27 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic */ private def replaceTableScanWithPartitionMetadata( child: LogicalPlan, - relation: LogicalPlan): LogicalPlan = { + relation: LogicalPlan, + partFilters: Seq[Expression]): LogicalPlan = { + // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the + // relation's schema. PartitionedRelation ensures that the filters only reference partition cols + val relFilters = partFilters.map { e => + e transform { + case a: AttributeReference => + a.withName(relation.output.find(_.semanticEquals(a)).get.name) + } + } + child transform { case plan if plan eq relation => relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(Nil, Nil) - LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) + val partitionData = fsRelation.location.listFiles(relFilters, Nil) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -113,12 +126,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) - val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => + val partitions = if (partFilters.nonEmpty) { + catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + } else { + catalog.listPartitions(relation.tableMeta.identifier) + } + + val partitionData = partitions.map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - LocalRelation(partAttrs, partitionData) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.toArray) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -129,35 +151,47 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes and the table relation node. + * pair of the partition attributes, partition filters, and the table relation node. * * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with * deterministic expressions, and returns result after reaching the partitioned table relation * node. */ - object PartitionedRelation { - - def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => - val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - Some((AttributeSet(partAttrs), l)) - - case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) - Some((AttributeSet(partAttrs), relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, relation) => - if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None - } + object PartitionedRelation extends PredicateHelper { + + def unapply( + plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + plan match { + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) + if fsRelation.partitionSchema.nonEmpty => + val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) + Some((partAttrs, partAttrs, Nil, l)) + + case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = AttributeSet( + getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) + Some((partAttrs, partAttrs, Nil, relation)) + + case p @ Project(projectList, child) if projectList.forall(_.deterministic) => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (p.references.subsetOf(attrs)) { + Some((partAttrs, p.outputSet, filters, relation)) + } else { + None + } + } - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, relation) => - if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None - } + case f @ Filter(condition, child) if condition.deterministic => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (f.references.subsetOf(partAttrs)) { + Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) + } else { + None + } + } - case _ => None + case _ => None + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala new file mode 100644 index 0000000000000..95f192f0e40e2 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton + with BeforeAndAfter with SQLTestUtils { + + import spark.implicits._ + + before { + sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") + (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) + } + + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the number of matching partitions + assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5) + + // verify that the partition predicate was pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5) + } + } + + test("SPARK-23877: filter on projected expression") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the matching partitions + val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, + Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), + spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) + .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) + + checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) + + // verify that the partition predicate was not pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) + } + } +} From e6b466084c26fbb9b9e50dd5cc8b25da7533ac72 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Fri, 20 Apr 2018 14:58:11 +0900 Subject: [PATCH 349/593] [SPARK-23736][SQL] Extending the concat function to support array columns ## What changes were proposed in this pull request? The PR adds a logic for easy concatenation of multiple array columns and covers: - Concat expression has been extended to support array columns - A Python wrapper ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite - typeCoercion/native/concat.sql ## Codegen examples ### Primitive-type elements ``` val df = Seq( (Seq(1 ,2), Seq(3, 4)), (Seq(1, 2, 3), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 070 */ project_numElements, /* 071 */ 4); /* 072 */ if (project_size > 2147483632) { /* 073 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size + /* 074 */ " bytes of data due to exceeding the limit 2147483632 bytes" + /* 075 */ " for UnsafeArrayData."); /* 076 */ } /* 077 */ /* 078 */ byte[] project_array = new byte[(int)project_size]; /* 079 */ UnsafeArrayData project_arrayData = new UnsafeArrayData(); /* 080 */ Platform.putLong(project_array, 16, project_numElements); /* 081 */ project_arrayData.pointTo(project_array, 16, (int)project_size); /* 082 */ int project_counter = 0; /* 083 */ for (int y = 0; y < 2; y++) { /* 084 */ for (int z = 0; z < args[y].numElements(); z++) { /* 085 */ if (args[y].isNullAt(z)) { /* 086 */ project_arrayData.setNullAt(project_counter); /* 087 */ } else { /* 088 */ project_arrayData.setInt( /* 089 */ project_counter, /* 090 */ args[y].getInt(z) /* 091 */ ); /* 092 */ } /* 093 */ project_counter++; /* 094 */ } /* 095 */ } /* 096 */ return project_arrayData; /* 097 */ } /* 098 */ }.concat(project_args); /* 099 */ boolean project_isNull = project_value == null; ``` ### Non-primitive-type elements ``` val df = Seq( (Seq("aa" ,"bb"), Seq("ccc", "ddd")), (Seq("x", "y"), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ Object[] project_arrayObjects = new Object[(int)project_numElements]; /* 070 */ int project_counter = 0; /* 071 */ for (int y = 0; y < 2; y++) { /* 072 */ for (int z = 0; z < args[y].numElements(); z++) { /* 073 */ project_arrayObjects[project_counter] = args[y].getUTF8String(z); /* 074 */ project_counter++; /* 075 */ } /* 076 */ } /* 077 */ return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects); /* 078 */ } /* 079 */ }.concat(project_args); /* 080 */ boolean project_isNull = project_value == null; ``` Author: mn-mikke Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master. --- .../spark/unsafe/array/ByteArrayMethods.java | 6 +- python/pyspark/sql/functions.py | 34 +-- .../catalyst/expressions/UnsafeArrayData.java | 10 + .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../expressions/collectionOperations.scala | 220 +++++++++++++++++- .../expressions/stringExpressions.scala | 81 ------- .../CollectionExpressionsSuite.scala | 41 ++++ .../org/apache/spark/sql/functions.scala | 20 +- .../inputs/typeCoercion/native/concat.sql | 62 +++++ .../typeCoercion/native/concat.sql.out | 78 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 74 ++++++ .../sql/execution/command/DDLSuite.scala | 4 +- 13 files changed, 529 insertions(+), 111 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 4bc9955090fd7..ef0f78d95d1ee 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1be68f2a4a448..da32ab25cad0c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1425,21 +1425,6 @@ def hash(*cols): del _name, _doc -@since(1.5) -@ignore_unicode_prefix -def concat(*cols): - """ - Concatenates multiple input columns together into a single column. - If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) - - @since(1.5) @ignore_unicode_prefix def concat_ws(sep, *cols): @@ -1845,6 +1830,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input columns together into a single column. + The function works with strings, binary and compatible array columns. + + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + @since(2.4) def array_position(col, value): """ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8546c28335536..d5d934bc91cab 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -56,9 +56,19 @@ public final class UnsafeArrayData extends ArrayData { public static int calculateHeaderPortionInBytes(int numFields) { + return (int)calculateHeaderPortionInBytes((long)numFields); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a44f2d5272b8e..c41f16c61d7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -308,7 +308,6 @@ object FunctionRegistry { expression[BitLength]("bit_length"), expression[Length]("char_length"), expression[Length]("character_length"), - expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), expression[Elt]("elt"), @@ -413,6 +412,7 @@ object FunctionRegistry { expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), + expression[Concat]("concat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 281f206e8d59e..cfcbd8db559a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -520,6 +520,14 @@ object TypeCoercion { case None => a } + case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case None => c + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dba426e999dda..c16793bda028e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** * Given an array or map, returns its size. Returns -1 if null. @@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def prettyName: String = "element_at" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + + s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: Nil) + ev.copy(s""" + $initCode + $codes + $javaType ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (code, numElements) + } + + private def nullArgumentProtection() : String = { + if (nullable) { + s""" + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} + """.stripMargin + } else { + "" + } + } + + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + + | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + + | " for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | $unsafeArraySizeInBytes + | byte[] $arrayName = new byte[(int)$arraySizeName]; + | UnsafeArrayData $arrayData = new UnsafeArrayData(); + | Platform.putLong($arrayName, $baseOffset, $numElemName); + | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5a02ca0d6862c..ea005a26a4c8b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * An expression that concatenates multiple inputs into a single output. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * If any input is null, concat returns null. - */ -@ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private lazy val isBinaryMode: Boolean = dataType == BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } - } - - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - if (isBinaryMode) { - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - } else { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } - - val (concatenator, initCode) = if (isBinaryMode) { - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - } else { - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - override def toString: String = s"concat(${children.mkString(", ")})" - - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" -} - - /** * An expression that concatenates multiple input strings or array of strings into a single string, * using a given separator (the first child). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7d8fe211858b2..43c5dda2e4a48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m2, Literal("a")), null) } + + test("Concat") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val ai4 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5)) + checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5)) + checkEvaluation(Concat(Seq(ai4)), null) + checkEvaluation(Concat(Seq(ai0, ai4)), null) + checkEvaluation(Concat(Seq(ai4, ai0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) + val as4 = Literal.create(null, ArrayType(StringType)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + + checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e")) + checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e")) + checkEvaluation(Concat(Seq(as4)), null) + checkEvaluation(Concat(Seq(as0, as4)), null) + checkEvaluation(Concat(Seq(as4, as0)), null) + + checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9c8580378303e..bea8c0e445002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2228,16 +2228,6 @@ object functions { */ def base64(e: Column): Column = withExpr { Base64(e.expr) } - /** - * Concatenates multiple input columns together into a single column. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } - /** * Concatenates multiple input string columns together into a single string column, * using the given separator. @@ -3038,6 +3028,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } + /** * Locates the position of the first occurrence of the value in the given array as long. * Returns null if either of the arguments are null. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 0beebec5702fd..db00a18f2e7e9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -91,3 +91,65 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +); + +-- Concatenate arrays of the same type +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays; + +-- Concatenate arrays of different types +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 09729fdc2ec32..62befc5ca0f15 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -237,3 +237,81 @@ struct 78910 891011 9101112 + + +-- !query 11 +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays +-- !query 12 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,data_array:array,timestamp_array:array,string_array:array,array_array:array>,struct_array:array>,map_array:array>> +-- !query 12 output +[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}] + + +-- !query 13 +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays +-- !query 13 schema +struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +-- !query 13 output +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7c976c1b7f915..25e5cd60dd236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("concat function - arrays") { + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + + // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index cbd7f9d6f67be..3998ceca38b30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + Row("Usage: concat(col1, col2, ..., colN) - " + + "Returns the concatenation of col1, col2, ..., colN.") :: Nil ) // extended mode checkAnswer( From 074a7f90536493b607e8e74bcebf3a27ea49a49d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 14:43:47 +0200 Subject: [PATCH 350/593] [SPARK-23588][SQL][FOLLOW-UP] Resolve a map builder method per execution in CatalystToExternalMap ## What changes were proposed in this pull request? This pr is a follow-up pr of #20979 and fixes code to resolve a map builder method per execution instead of per row in `CatalystToExternalMap`. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21112 from maropu/SPARK-23588-FOLLOWUP. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420a..32c1f34ef97a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1040,11 +1040,13 @@ case class CatalystToExternalMap private( private lazy val valueConverter = CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") - val module = clazz.getField("MODULE$").get(null) - val method = clazz.getMethod("newBuilder") - method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } override def eval(input: InternalRow): Any = { From 0dd97f6ea4affde1531dec1bec004b7ab18c6965 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 15:02:27 +0200 Subject: [PATCH 351/593] [SPARK-23595][SQL] ValidateExternalType should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ValidateExternalType`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20757 from maropu/SPARK-23595. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../expressions/objects/objects.scala | 34 ++++++++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 33 ++++++++++++++++-- 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a8..f9acc208b715e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 32c1f34ef97a5..f1ffcaec8a484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -1672,13 +1673,36 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. @@ -1691,7 +1715,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0b..7136af8934486 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } } class TestBean extends Serializable { From 1d758dc73b54e802fdc92be204185fe7414e6553 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Apr 2018 10:23:01 -0700 Subject: [PATCH 352/593] Revert "[SPARK-23775][TEST] Make DataFrameRangeSuite not flaky" This reverts commit 0c94e48bc50717e1627c0d2acd5382d9adc73c97. --- .../spark/sql/DataFrameRangeSuite.scala | 78 ++++++++----------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index a0fd74088ce8b..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql -import java.util.concurrent.{CountDownLatch, TimeUnit} - import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -154,53 +152,39 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - // Save and restore the value because SparkContext is shared - val savedInterruptOnCancel = sparkContext - .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) - - try { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") - - for (codegen <- Seq(true, false)) { - // This countdown latch used to make sure with all the stages cancelStage called in listener - val latch = new CountDownLatch(2) - - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - sparkContext.cancelStage(taskStart.stageId) - latch.countDown() - } + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + eventually(timeout(10.seconds), interval(1.millis)) { + assert(DataFrameRangeSuite.stageToKill > 0) } + sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + } + } - sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => - x.synchronized { - x.wait() - } - x - }.toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } + sparkContext.addSparkListener(listener) + for (codegen <- Seq(true, false)) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + DataFrameRangeSuite.stageToKill = -1 + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1).map { x => + DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() + x + }.toDF("id").agg(sum("id")).collect() } - latch.await(20, TimeUnit.SECONDS) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } - sparkContext.removeSparkListener(listener) } - } finally { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, - savedInterruptOnCancel) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -220,3 +204,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + +object DataFrameRangeSuite { + @volatile var stageToKill = -1 +} From 32b4bcd6d31b92b179a15f9886779fc5f96404b5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 21 Apr 2018 23:14:58 +0800 Subject: [PATCH 353/593] [SPARK-24029][CORE] Set SO_REUSEADDR on listen sockets. This allows sockets to be bound even if there are sockets from a previous application that are still pending closure. It avoids bind issues when, for example, re-starting the SHS. Don't enable the option on Windows though. The following page explains some odd behavior that this option can have there: https://msdn.microsoft.com/en-us/library/windows/desktop/ms740621%28v=vs.85%29.aspx I intentionally ignored server sockets that always bind to ephemeral ports, since those don't benefit from this option. Author: Marcelo Vanzin Closes #21110 from vanzin/SPARK-24029. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 +++- .../org/apache/spark/deploy/rest/RestSubmissionServer.scala | 1 + core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bcc..612750972c4bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,7 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator); + .childOption(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f270..3d99d085408c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0e8a6307de6a8..d6a025a6f12da 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging { connectionFactories: _*) connector.setPort(port) connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads From 7bc853d08973a6bd839ad2222911eb0a0f413677 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Apr 2018 10:45:12 -0700 Subject: [PATCH 354/593] [SPARK-24033][SQL] Fix Mismatched of Window Frame specifiedwindowframe(RowFrame, -1, -1) ## What changes were proposed in this pull request? When the OffsetWindowFunction's frame is `UnaryMinus(Literal(1))` but the specified window frame has been simplified to `Literal(-1)` by some optimizer rules e.g., `ConstantFolding`. Thus, they do not match and cause the following error: ``` org.apache.spark.sql.AnalysisException: Window Frame specifiedwindowframe(RowFrame, -1, -1) must match the required frame specifiedwindowframe(RowFrame, -1, -1); at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:91) at ``` ## How was this patch tested? Added a test Author: gatorsmile Closes #21115 from gatorsmile/fixLag. --- .../catalyst/expressions/windowExpressions.scala | 5 ++++- .../spark/sql/DataFrameWindowFramesSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 78895f1c2f6f5..9fe2fb2b95e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -342,7 +342,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 0ee9b0edc02b2..2a0b2b85e10a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -402,4 +402,18 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: Row(10, 6000) :: Nil) } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } } From c48085aa91c60615a4de3b391f019f46f3fcdbe3 Mon Sep 17 00:00:00 2001 From: Mykhailo Shtelma Date: Sat, 21 Apr 2018 23:33:57 -0700 Subject: [PATCH 355/593] [SPARK-23799][SQL] FilterEstimation.evaluateInSet produces devision by zero in a case of empty table with analyzed statistics >What changes were proposed in this pull request? During evaluation of IN conditions, if the source data frame, is represented by a plan, that uses hive table with columns, which were previously analysed, and the plan has conditions for these fields, that cannot be satisfied (which leads us to an empty data frame), FilterEstimation.evaluateInSet method produces NumberFormatException and ClassCastException. In order to fix this bug, method FilterEstimation.evaluateInSet at first checks, if distinct count is not zero, and also checks if colStat.min and colStat.max are defined, and only in this case proceeds with the calculation. If at least one of the conditions is not satisfied, zero is returned. >How was this patch tested? In order to test the PR two tests were implemented: one in FilterEstimationSuite, that tests the plan with the statistics that violates the conditions mentioned above, and another one in StatisticsCollectionSuite, that test the whole process of analysis/optimisation of the query, that leads to the problems, mentioned in the first section. Author: Mykhailo Shtelma Author: smikesh Closes #21052 from mshtelma/filter_estimation_evaluateInSet_Bugs. --- .../statsEstimation/FilterEstimation.scala | 4 +++ .../FilterEstimationSuite.scala | 11 ++++++++ .../spark/sql/StatisticsCollectionSuite.scala | 28 +++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0538c9d88584b..263c9ba60d145 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,6 +392,10 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 43440d51dede6..16cb5d032cf57 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -357,6 +357,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 14a565863d66c..b91712f4cc25d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -382,4 +382,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } From c3a86faa53c9e49efd595802adc38a6d412ce681 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 Apr 2018 10:45:25 +0800 Subject: [PATCH 356/593] [SPARK-10399][SPARK-23879][FOLLOWUP][CORE] Free unused off-heap memory in MemoryBlockSuite ## What changes were proposed in this pull request? As viirya pointed out [here](https://github.com/apache/spark/pull/19222#discussion_r179910484), this PR explicitly frees unused off-heap memory in `MemoryBlockSuite` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21117 from kiszk/SPARK-10399-free-offheap. --- .../java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 5d5fdc1c55a75..ef5ff8ee70ec0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } catch (Exception expected) { Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); } @Test @@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() { int length = 56; check(memory, obj, offset, length); + memoryAllocator.free(memory); long address = Platform.allocateMemory(112); memory = new OffHeapMemoryBlock(address, length); obj = memory.getBaseObject(); offset = memory.getBaseOffset(); check(memory, obj, offset, length); + Platform.freeMemory(address); } } From f70f46d1e5bc503e9071707d837df618b7696d32 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:18:50 +0800 Subject: [PATCH 357/593] [SPARK-23877][SQL][FOLLOWUP] use PhysicalOperation to simplify the handling of Project and Filter over partitioned relation ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/20988 `PhysicalOperation` can collect Project and Filters over a certain plan and substitute the alias with the original attributes in the bottom plan. We can use it in `OptimizeMetadataOnlyQuery` rule to handle the Project and Filter over partitioned relation. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21111 from cloud-fan/refactor. --- .../plans/logical/LocalRelation.scala | 6 ++ .../sql/execution/LocalTableScanExec.scala | 3 + .../execution/OptimizeMetadataOnlyQuery.scala | 58 ++++++------------- .../OptimizeHiveMetadataOnlyQuerySuite.scala | 16 ++++- 4 files changed, 39 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b05508db786ad..720d42ab409a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,6 +43,12 @@ object LocalRelation { } } +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. + */ case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c7..448eb703eacde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. */ case class LocalTableScanExec( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index acbd4becb8549..3ca03ab2939aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -49,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( + projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(attrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all only refer to partition attributes, which means the + // the Aggregator operator can also only refer to partition attributes, and filters are + // all partition filters. This is a metadata only query we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -102,7 +107,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic partFilters: Seq[Expression]): LogicalPlan = { // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the // relation's schema. PartitionedRelation ensures that the filters only reference partition cols - val relFilters = partFilters.map { e => + val normalizedFilters = partFilters.map { e => e transform { case a: AttributeReference => a.withName(relation.output.find(_.semanticEquals(a)).get.name) @@ -114,11 +119,8 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(relFilters, Nil) - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) + LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -127,7 +129,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitions = if (partFilters.nonEmpty) { - catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) } else { catalog.listPartitions(relation.tableMeta.identifier) } @@ -137,10 +139,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.toArray) + LocalRelation(partAttrs, partitionData) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -151,44 +150,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes, partition filters, and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. + * pair of the partition attributes and the table relation node. */ object PartitionedRelation extends PredicateHelper { - def unapply( - plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => + if fsRelation.partitionSchema.nonEmpty => val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) - Some((partAttrs, partAttrs, Nil, l)) + Some((partAttrs, l)) case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = AttributeSet( getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) - Some((partAttrs, partAttrs, Nil, relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (p.references.subsetOf(attrs)) { - Some((partAttrs, p.outputSet, filters, relation)) - } else { - None - } - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (f.references.subsetOf(partAttrs)) { - Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) - } else { - None - } - } + Some((partAttrs, relation)) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 95f192f0e40e2..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -32,13 +33,22 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto import spark.implicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) } + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS metadata_only") + } finally { + super.afterAll() + } + } + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the number of matching partitions @@ -50,7 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions From d87d30e4fe9c9e91c462351e9f744a830db8d6fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:21:01 +0800 Subject: [PATCH 358/593] [SPARK-23564][SQL] infer additional filters from constraints for join's children ## What changes were proposed in this pull request? The existing query constraints framework has 2 steps: 1. propagate constraints bottom up. 2. use constraints to infer additional filters for better data pruning. For step 2, it mostly helps with Join, because we can connect the constraints from children to the join condition and infer powerful filters to prune the data of the join sides. e.g., the left side has constraints `a = 1`, the join condition is `left.a = right.a`, then we can infer `right.a = 1` to the right side and prune the right side a lot. However, the current logic of inferring filters from constraints for Join is pretty weak. It infers the filters from Join's constraints. Some joins like left semi/anti exclude output from right side and the right side constraints will be lost here. This PR propose to check the left and right constraints individually, expand the constraints with join condition and add filters to children of join directly, instead of adding to the join condition. This reverts https://github.com/apache/spark/pull/20670 , covers https://github.com/apache/spark/pull/20717 and https://github.com/apache/spark/pull/20816 This is inspired by the original PRs and the tests are all from these PRs. Thanks to the authors mgaido91 maryannxue KaiXinXiaoLei ! ## How was this patch tested? new tests Author: Wenchen Fan Closes #21083 from cloud-fan/join. --- .../sql/catalyst/optimizer/Optimizer.scala | 97 +++++++++---------- .../plans/logical/QueryPlanConstraints.scala | 95 ++++++++---------- .../InferFiltersFromConstraintsSuite.scala | 53 +++++++--- 3 files changed, 124 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 913354e4df0e6..f00d40d11f23f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,13 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * In addition, for left/right outer joins, infer predicate from the preserved side of the Join - * operator and push the inferred filter over to the null-supplying side. For example, if the - * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in - * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will - * be applied to the null-supplying side. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.allConstraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt - case None => - additionalConstraints.reduceOption(And) - } - // Infer filter for left/right outer joins - val newLeftOpt = joinType match { - case RightOuter if newConditionOpt.isDefined => - val inferredConstraints = left.getRelevantConstraints( - left.constraints - .union(right.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(left.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, left)) - case _ => None - } - val newRightOpt = joinType match { - case LeftOuter if newConditionOpt.isDefined => - val inferredConstraints = right.getRelevantConstraints( - right.constraints - .union(left.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(right.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, right)) - case _ => None - } + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. + case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) - if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) - || newLeftOpt.isDefined || newRightOpt.isDefined) { - Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) - } else { - join + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a29f3d29236c7..cc352c59dff80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains an additional set of constraints, such as equality - * constraints and `isNotNull` constraints, etc. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val allConstraints: ExpressionSet = { + lazy val constraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } } - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) - /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { /** - * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as - * equality constraints and `isNotNull` constraints, etc., and that only contains references - * to this [[LogicalPlan]] node. + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. */ - def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { - val allRelevantConstraints = - if (conf.constraintPropagationEnabled) { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - } else { - constraints - } - ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints } + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) - - private def selfReferenceOnly(e: Expression): Boolean = { - e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e068f51044589..e4671f0d1cce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-21479: Outer join no filter push down to preserved side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a) && 'a === 1) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } From afbdf427302aba858f95205ecef7667f412b2a6a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 23 Apr 2018 14:28:28 +0200 Subject: [PATCH 359/593] [SPARK-23589][SQL] ExternalMapToCatalyst should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ExternalMapToCatalyst`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20980 from maropu/SPARK-23589. --- .../expressions/objects/objects.scala | 60 +++++++++- .../expressions/ObjectExpressionsSuite.scala | 108 +++++++++++++++++- 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f1ffcaec8a484..9c7e76467d153 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,8 +1255,64 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 7136af8934486..730b36c32333c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -501,6 +502,111 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(1))), "java.lang.Integer is not a valid external type for schema of double") } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } } class TestBean extends Serializable { From 293a0f29e314dc532cec2048a7c6bc00e31de472 Mon Sep 17 00:00:00 2001 From: Teng Peng Date: Mon, 23 Apr 2018 10:29:47 -0700 Subject: [PATCH 360/593] [Spark-24024][ML] Fix poisson deviance calculations in GLM to handle y = 0 ## What changes were proposed in this pull request? It is reported by Spark users that the deviance calculation for poisson regression does not handle y = 0. Thus, the correct model summary cannot be obtained. The user has confirmed the the issue is in ``` override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (y * math.log(y / mu) - (y - mu)) } when y = 0. ``` The user also mentioned there are many other places he believe we should check the same thing. However, no other changes are needed, including Gamma distribution. ## How was this patch tested? Add a comparison with R deviance calculation to the existing unit test. Author: Teng Peng Closes #21125 from tengpeng/Spark24024GLM. --- .../ml/regression/GeneralizedLinearRegression.scala | 10 +++++----- .../regression/GeneralizedLinearRegressionSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 9f1f2405c428e..4c3f1431d5077 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -471,6 +471,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -725,10 +729,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +783,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d5bcbb221783e..997c50157dcda 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import GeneralizedLinearRegression._ var idx = 0 @@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } From 448d248f897fa39cfc82d71a3d6b67e6470f8a02 Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Mon, 23 Apr 2018 13:56:11 -0500 Subject: [PATCH 361/593] [SPARK-21168] KafkaRDD should always set kafka clientId. [https://issues.apache.org/jira/browse/SPARK-21168](https://issues.apache.org/jira/browse/SPARK-21168) There are no a number of other places that a client ID should be set,and I think we should use consumer.clientId in the clientId method,because the fetch request will be used by the same consumer behind. Author: liuzhaokun Closes #19887 from liu-zhaokun/master1205. --- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a0..791cf0efaf888 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) From 770add81c3474e754867d7105031a5eaf27159bd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Apr 2018 13:20:32 -0700 Subject: [PATCH 362/593] [SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A structured streaming query with a streaming aggregation can throw the following error in rare cases.  ``` java.lang.IllegalStateException: Cannot commit after already committed or aborted at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643) at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359) at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336) ``` This can happen when the following conditions are accidentally hit.  - Streaming aggregation with aggregation function that is a subset of [`TypedImperativeAggregation`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).  - Query running in `update}` mode - After the shuffle, a partition has exactly 128 records.  This causes StateStore.commit to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag to prevent the "onCompletion" task from being called more than once. In this PR, I chose to implement using `NextIterator`. ## How was this patch tested? Added unit test that I have confirm will fail without the fix. Author: Tathagata Das Closes #21124 from tdas/SPARK-23004. --- .../streaming/statefulOperators.scala | 40 +++++++++---------- .../streaming/StreamingAggregationSuite.scala | 25 ++++++++++++ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f9..c9354ac0ec78a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -340,37 +340,35 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1cae8cb8d47f1..382da13430781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From e82cb68349b785c1b35bcfb85bff3a8ec2c93fee Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 23 Apr 2018 13:23:02 -0700 Subject: [PATCH 363/593] [SPARK-11237][ML] Add pmml export for k-means in Spark ML ## What changes were proposed in this pull request? Adding PMML export to Spark ML's KMeans Model. ## How was this patch tested? New unit test for Spark ML PMML export based on the old Spark MLlib unit test. Author: Holden Karau Closes #20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans. --- .../org.apache.spark.ml.util.MLFormatRegister | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 75 ++++++++++++------- .../ml/regression/LinearRegression.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 32 +++++++- 4 files changed, 83 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 5e5484fd8784d..f14431d50feec 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1,2 +1,4 @@ org.apache.spark.ml.regression.InternalLinearRegressionModelWriter -org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 987a4285ebad4..1ad157a695a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -152,14 +154,14 @@ class KMeansModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +187,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. */ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f67d9d831f327..9cdd3a051e719 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter /** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter - extends MLWriterFormat with MLFormatRegister { + extends MLWriterFormat with MLFormatRegister { override def format(): String = "pmml" diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32830b39407ad..77c9d482d95b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering import scala.util.Random +import org.dmg.pmml.{ClusteringModel, PMML} + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { final val k = 5 @transient var dataset: Dataset[_] = _ @@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header descripiton is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark + // model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { From c8f3ac69d176bd10b8de1c147b6903a247943d51 Mon Sep 17 00:00:00 2001 From: wuyi Date: Mon, 23 Apr 2018 15:35:45 -0500 Subject: [PATCH 364/593] [SPARK-23888][CORE] correct the comment of hasAttemptOnHost() TaskSetManager.hasAttemptOnHost had a misleading comment. The comment said that it only checked for running tasks, but really it checked for any tasks that might have run in the past as well. This updates to line up with the implementation. Author: wuyi Closes #20998 from Ngone51/SPARK-23888. --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d958658527f6d..8a96a7692f614 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -287,7 +287,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } From 428b903859c3d8873045fdcfffdebe24fc6e027f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Apr 2018 09:10:29 +0800 Subject: [PATCH 365/593] [SPARK-24029][CORE] Follow up: set SO_REUSEADDR on the server socket. "childOption" is for the remote connections, not for the server socket that actually listens for incoming connections. Author: Marcelo Vanzin Closes #21132 from vanzin/SPARK-24029.2. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 612750972c4bb..60f51125c07fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -99,8 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) + .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); From 281c1ca0dc96b0441a60c32df3d16fbb1c61e99f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Apr 2018 10:11:09 +0800 Subject: [PATCH 366/593] [SPARK-23973][SQL] Remove consecutive Sorts ## What changes were proposed in this pull request? In SPARK-23375 we introduced the ability of removing `Sort` operation during query optimization if the data is already sorted. In this follow-up we remove also a `Sort` which is followed by another `Sort`: in this case the first sort is not needed and can be safely removed. The PR starts from henryr's comment: https://github.com/apache/spark/pull/20560#discussion_r180601594. So credit should be given to him. ## How was this patch tested? added UT Author: Marco Gaido Closes #21072 from mgaido91/SPARK-23973. --- .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++++- .../optimizer/RemoveRedundantSortsSuite.scala | 51 ++++++++++++++++--- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f00d40d11f23f..45f13956a0a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -767,12 +767,29 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operation if the child is already sorted + * Removes redundant Sort operation. This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators */ object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 2319ab8046e56..dae5e6f3ee3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.select('a).analyze + val correctAnswer = orderedPlan.limit(2).select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest { val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } } From c303b1b6766a3dc5961713f98f62cd7d7ac7972a Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 24 Apr 2018 16:16:07 +0800 Subject: [PATCH 367/593] [MINOR][DOCS] Fix comments of SQLExecution#withExecutionId ## What changes were proposed in this pull request? Fix comment. Change `BroadcastHashJoin.broadcastFuture` to `BroadcastExchangeExec.relationFuture`: https://github.com/apache/spark/blob/d28d5732ae205771f1f443b15b10e64dcffb5ff0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala#L66 ## How was this patch tested? N/A Author: seancxmao Closes #21113 from seancxmao/SPARK-13136. --- .../scala/org/apache/spark/sql/execution/SQLExecution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bde..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -88,7 +88,7 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) From 87e8a572be14381da9081365d9aa2cbf3253a32c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Apr 2018 16:18:20 +0800 Subject: [PATCH 368/593] [SPARK-24054][R] Add array_position function / element_at functions ## What changes were proposed in this pull request? This PR proposes to add array_position and element_at in R side too. array_position: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_position(mutated$v1, 1))) ``` ``` array_position(v1, 1.0) 1 2 2 2 3 2 4 3 5 0 6 3 ``` element_at: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, element_at(mutated$v1, 1))) ``` ``` element_at(v1, 1.0) 1 21.0 2 21.0 3 22.8 4 21.4 5 18.7 6 18.1 ``` ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_map(df$model, df$cyl)) head(select(mutated, element_at(mutated$v1, "Valiant"))) ``` ``` element_at(v3, Valiant) 1 NA 2 NA 3 NA 4 NA 5 NA 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21130 from HyukjinKwon/sparkr_array_position_element_at. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 42 +++++++++++++++++++++++++-- R/pkg/R/generics.R | 8 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++-- 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50ea10482..55dec177ea853 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_position", "asc", "ascii", "asin", @@ -245,6 +246,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426b19674..7b3aa05074563 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,11 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. @@ -201,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -208,7 +214,8 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} NULL #' Window functions for Column operations @@ -2975,7 +2982,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2986,6 +2992,22 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3012,6 +3034,22 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff1a3d76..f30ac9e4295e4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469ffc242..a384997830276 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,17 +1479,23 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_position(), element_at() and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - # Test map_keys() and map_values() + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) @@ -1497,6 +1503,9 @@ test_that("column functions", { result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) From 4926a7c2f0a47b562f99dbb4f1ca17adb3192061 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 24 Apr 2018 17:52:05 +0200 Subject: [PATCH 369/593] [SPARK-23589][SQL][FOLLOW-UP] Reuse InternalRow in ExternalMapToCatalyst eval ## What changes were proposed in this pull request? This pr is a follow-up of #20980 and fixes code to reuse `InternalRow` for converting input keys/values in `ExternalMapToCatalyst` eval. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21137 from maropu/SPARK-23589-FOLLOWUP. --- .../expressions/objects/objects.scala | 92 ++++++++++--------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9c7e76467d153..f974fd81fc788 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,53 +1255,61 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { - case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[java.util.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - val iter = data.entrySet().iterator() - var i = 0 - while (iter.hasNext) { - val entry = iter.next() - val (key, value) = (entry.getKey, entry.getValue) - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } - case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[scala.collection.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - var i = 0 - for ((key, value) <- data) { - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } + } } override def eval(input: InternalRow): Any = { From 55c4ca88a3b093ee197a8689631be8d1fac1f10f Mon Sep 17 00:00:00 2001 From: Julien Cuquemelle Date: Tue, 24 Apr 2018 10:56:55 -0500 Subject: [PATCH 370/593] [SPARK-22683][CORE] Add a executorAllocationRatio parameter to throttle the parallelism of the dynamic allocation ## What changes were proposed in this pull request? By default, the dynamic allocation will request enough executors to maximize the parallelism according to the number of tasks to process. While this minimizes the latency of the job, with small tasks this setting can waste a lot of resources due to executor allocation overhead, as some executor might not even do any work. This setting allows to set a ratio that will be used to reduce the number of target executors w.r.t. full parallelism. The number of executors computed with this setting is still fenced by `spark.dynamicAllocation.maxExecutors` and `spark.dynamicAllocation.minExecutors` ## How was this patch tested? Units tests and runs on various actual workloads on a Yarn Cluster Author: Julien Cuquemelle Closes #19881 from jcuquemelle/AddTaskPerExecutorSlot. --- .../spark/ExecutorAllocationManager.scala | 24 +++++++++++--- .../spark/internal/config/package.scala | 4 +++ .../ExecutorAllocationManagerSuite.scala | 33 +++++++++++++++++++ docs/configuration.md | 18 ++++++++++ 4 files changed, 74 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c045..aa363eeffffb8 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -69,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -116,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -209,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -273,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 99d779fb600e8..6bb98c37b4479 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -126,6 +126,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d4..3cfb0a9feb32b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -145,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get diff --git a/docs/configuration.md b/docs/configuration.md index 4d4d0c58dd07d..fb02d7ea1d4ea 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1753,6 +1753,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio
    spark.dynamicAllocation.executorAllocationRatio1 + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overriden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings +
    spark.dynamicAllocation.schedulerBacklogTimeout 1s
    Property NameDefaultMeaning
    spark.sql.orc.implhivenative The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1.
    spark.yarn.am.waitTime 100s - In cluster mode, time for the YARN Application Master to wait for the - SparkContext to be initialized. In client mode, time for the YARN Application Master to wait - for the driver to connect to it. + Only used in cluster mode. Time for the YARN Application Master to wait for the + SparkContext to be initialized.
    1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.2. + options are 0.12.0 through 2.3.3.
    queryTimeout + The number of seconds the driver will wait for a Statement object to execute to the given + number of seconds. Zero means there is no limit. In the write path, this option depends on + how JDBC drivers implement the API setQueryTimeout, e.g., the h2 JDBC driver + checks the timeout of each query instead of an entire JDBC batch. + It defaults to 0. +
    fetchsize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaff..917f0cb221412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. "fetchsize" can be used to control the - * number of rows per fetch. + * number of rows per fetch and "queryTimeout" can be used to wait + * for a Statement object to execute to the given number of seconds. * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d9..a73a97c06fe5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -89,6 +89,10 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) + // the number of seconds the driver will wait for a Statement object to execute to the given + // number of seconds. Zero means there is no limit. + val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt + // ------------------------------------------------------------ // Optional parameters only for reading // ------------------------------------------------------------ @@ -160,6 +164,7 @@ object JDBCOptions { val JDBC_LOWER_BOUND = newOption("lowerBound") val JDBC_UPPER_BOUND = newOption("upperBound") val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05326210f3242..0bab3689e5d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -57,6 +57,7 @@ object JDBCRDD extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { + statement.setQueryTimeout(options.queryTimeout) val rs = statement.executeQuery() try { JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) @@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD( val statement = conn.prepareStatement(sql) logInfo(s"Executing sessionInitStatement: $sql") try { + statement.setQueryTimeout(options.queryTimeout) statement.execute() } finally { statement.close() @@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD( stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) + stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index cc506e51bd0c6..f8c5677ea0f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider saveTable(df, tableSchema, isCaseSensitive, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it - dropTable(conn, options.table) + dropTable(conn, options.table, options) createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1b..433443007cfd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -76,6 +76,7 @@ object JdbcUtils extends Logging { Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) statement.executeQuery() } finally { statement.close() @@ -86,9 +87,10 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = { val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(s"DROP TABLE $table") } finally { statement.close() @@ -102,6 +104,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() @@ -254,6 +257,7 @@ object JdbcUtils extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) } catch { case _: SQLException => None @@ -596,7 +600,8 @@ object JdbcUtils extends Logging { insertStmt: String, batchSize: Int, dialect: JdbcDialect, - isolationLevel: Int): Iterator[Byte] = { + isolationLevel: Int, + options: JDBCOptions): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -637,6 +642,9 @@ object JdbcUtils extends Logging { try { var rowCount = 0 + + stmt.setQueryTimeout(options.queryTimeout) + while (iterator.hasNext) { val row = iterator.next() var i = 0 @@ -819,7 +827,8 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, + options) ) } @@ -841,6 +850,7 @@ object JdbcUtils extends Logging { val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(sql) } finally { statement.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a699..bc2aca65e803f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1190,4 +1190,20 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").schema === schema) } } + + test("SPARK-23856 Spark jdbc setQueryTimeout option") { + val numJoins = 100 + val longRunningQuery = + s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " + + s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}" + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbtable", s"($longRunningQuery)") + .option("queryTimeout", 1) + .load() + val errMsg = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1985b1dc82879..1c2c92d1f0737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -515,4 +515,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { }.getMessage assert(e.contains("NULL not allowed for column \"NAME\"")) } + + ignore("SPARK-23856 Spark jdbc setQueryTimeout option") { + // The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API + // `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple + // INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout + // of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver + // checks the timeout of an entire batch in a driver side. So, the test below fails because + // this test suite depends on the h2 JDBC driver and the JDBC write path internally + // uses `executeBatch`. + val errMsg = intercept[SparkException] { + spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write + .mode(SaveMode.Overwrite) + .option("queryTimeout", 1) + .option("batchsize", Int.MaxValue) + .jdbc(url1, "TEST.TIMEOUTTEST", properties) + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } From 710e4e81a8efc1aacc14283fb57bc8786146f885 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 18 May 2018 14:37:01 -0700 Subject: [PATCH 522/593] [SPARK-24308][SQL] Handle DataReaderFactory to InputPartition rename in left over classes ## What changes were proposed in this pull request? SPARK-24073 renames DataReaderFactory -> InputPartition and DataReader -> InputPartitionReader. Some classes still reflects the old name and causes confusion. This patch renames the left over classes to reflect the new interface and fixes a few comments. ## How was this patch tested? Existing unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21355 from arunmahadevan/SPARK-24308. --- .../spark/sql/kafka010/KafkaContinuousReader.scala | 6 +++--- .../spark/sql/kafka010/KafkaMicroBatchReader.scala | 4 ++-- .../sql/kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sources/v2/reader/ContinuousInputPartition.java | 4 ++-- .../spark/sql/sources/v2/reader/InputPartition.java | 6 +++--- .../sql/sources/v2/reader/InputPartitionReader.java | 6 +++--- .../sql/execution/datasources/v2/DataSourceRDD.scala | 6 +++--- .../continuous/ContinuousRateStreamSource.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 4 ++-- .../streaming/sources/ContinuousMemoryStream.scala | 12 ++++++------ .../sources/RateStreamMicroBatchReader.scala | 4 ++-- .../streaming/sources/RateStreamProviderSuite.scala | 2 +- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 88abf8a8dd027..badaa69cc303c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -106,7 +106,7 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousDataReaderFactory( + KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) .asInstanceOf[InputPartition[UnsafeRow]] }.asJava @@ -146,7 +146,7 @@ class KafkaContinuousReader( } /** - * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. @@ -156,7 +156,7 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousDataReaderFactory( +case class KafkaContinuousInputPartition( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a377738ea782..64ba98762788c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -143,7 +143,7 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges val factories = offsetRanges.map { range => - new KafkaMicroBatchDataReaderFactory( + new KafkaMicroBatchInputPartition( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava @@ -300,7 +300,7 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReaderFactory( +private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 871f9700cd1db..c6412eac97dba 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -679,7 +679,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) val factories = reader.planUnsafeInputPartitions().asScala - .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index c24f3b21eade1..dcb87715d0b6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -27,9 +27,9 @@ @InterfaceStability.Evolving public interface ContinuousInputPartition extends InputPartition { /** - * Create a DataReader with particular offset as its startOffset. + * Create an input partition reader with particular offset as its startOffset. * - * @param offset offset want to set as the DataReader's startOffset. + * @param offset offset want to set as the input partition reader's startOffset. */ InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 3524481784fea..f53687e113ae0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -36,8 +36,8 @@ public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this partition can run faster, - * but Spark does not guarantee to run the data reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run faster, + * but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * @@ -53,7 +53,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work. + * Returns an input partition reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 1b7051f1ad0af..f0d808536207a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,11 +23,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for * outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 1a6b32429313a..8d6fb3820d420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,12 +29,12 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: I class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[InputPartition[T]]) + @transient private val inputPartitions: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 8d25d9ccc43d3..516a563bdcc7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -85,7 +85,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousDataReaderFactory( + RateStreamContinuousInputPartition( start.value, start.runTimeMs, i, @@ -113,7 +113,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) } -case class RateStreamContinuousDataReaderFactory( +case class RateStreamContinuousInputPartition( startValue: Long, startTimeMs: Long, partitionIndex: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index daa2963220aef..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] + new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -201,7 +201,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) +class MemoryStreamInputPartition(records: Array[UnsafeRow]) extends InputPartition[UnsafeRow] { override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { new InputPartitionReader[UnsafeRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 4daafa65850de..d1c3498450096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.RpcUtils * * ContinuousMemoryStream maintains a list of records for each partition. addData() will * distribute records evenly-ish across partitions. * * RecordEndpoint is set up as an endpoint for executor-side - * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified - * offset within the list, or null if that offset doesn't yet have a record. + * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { @@ -106,7 +106,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa startOffset.partitionNums.map { case (part, index) => - new ContinuousMemoryStreamDataReaderFactory( + new ContinuousMemoryStreamInputPartition( endpointName, part, index): InputPartition[Row] }.toList.asJava } @@ -157,9 +157,9 @@ object ContinuousMemoryStream { } /** - * Data reader factory for continuous memory stream. + * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamDataReaderFactory( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, startOffset: Int) extends InputPartition[Row] { @@ -168,7 +168,7 @@ class ContinuousMemoryStreamDataReaderFactory( } /** - * Data reader for continuous memory stream. + * An input partition reader for continuous memory stream. * * Polls the driver endpoint for new records. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 723cc3ad5bb89..fbff8db987110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -167,7 +167,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: } (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( + new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) : InputPartition[Row] }.toList.asJava @@ -182,7 +182,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchDataReaderFactory( +class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 39a010f970ce5..bf72e5c99689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -309,7 +309,7 @@ class RateSourceSuite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => + case t: RateStreamContinuousInputPartition => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) From 434d74e337465d77fa49ab65e2b5461e5ff7b5c7 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Fri, 18 May 2018 16:54:39 -0700 Subject: [PATCH 523/593] [SPARK-23503][SS] Enforce sequencing of committed epochs for Continuous Execution ## What changes were proposed in this pull request? Made changes to EpochCoordinator so that it enforces a commit order. In case a message for epoch n is lost and epoch (n + 1) is ready for commit before epoch n is, epoch (n + 1) will wait for epoch n to be committed first. ## How was this patch tested? Existing tests in ContinuousSuite and EpochCoordinatorSuite. Author: Efim Poberezkin Closes #20936 from efimpoberezkin/pr/sequence-commited-epochs. --- .../continuous/EpochCoordinator.scala | 69 +++++++++++++++---- .../continuous/EpochCoordinatorSuite.scala | 6 +- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index cc6808065c0cd..8877ebeb26735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator( private val partitionOffsets = mutable.Map[(Long, Int), PartitionOffset]() + private var lastCommittedEpoch = startEpoch - 1 + // Remembers epochs that have to wait for previous epochs to be committed first. + private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long] + private def resolveCommitsAtEpoch(epoch: Long) = { - val thisEpochCommits = - partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val thisEpochCommits = findPartitionCommitsForEpoch(epoch) val nextEpochOffsets = partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochCommits.size == numWriterPartitions && nextEpochOffsets.size == numReaderPartitions) { - logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") - // Sequencing is important here. We must commit to the writer before recording the commit - // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, thisEpochCommits.toArray) - query.commit(epoch) - - // Cleanup state from before this epoch, now that we know all partitions are forever past it. - for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) - } - for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + + // Check that last committed epoch is the previous one for sequencing of committed epochs. + // If not, add the epoch being currently processed to epochs waiting to be committed, + // otherwise commit it. + if (lastCommittedEpoch != epoch - 1) { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is waiting for epoch ${epoch - 1} to be committed first.") + epochsWaitingToBeCommitted.add(epoch) + } else { + commitEpoch(epoch, thisEpochCommits) + lastCommittedEpoch = epoch + + // Commit subsequent epochs that are waiting to be committed. + var nextEpoch = lastCommittedEpoch + 1 + while (epochsWaitingToBeCommitted.contains(nextEpoch)) { + val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch) + commitEpoch(nextEpoch, nextEpochCommits) + + epochsWaitingToBeCommitted.remove(nextEpoch) + lastCommittedEpoch = nextEpoch + nextEpoch += 1 + } + + // Cleanup state from before last committed epoch, + // now that we know all partitions are forever past it. + for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionCommits.remove(k) + } + for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionOffsets.remove(k) + } } } } + /** + * Collect per-partition commits for an epoch. + */ + private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = { + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + } + + /** + * Commit epoch to the offset log. + */ + private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is ready to be committed. Committing epoch $epoch.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, messages.toArray) + query.commit(epoch) + } + override def receive: PartialFunction[Any, Unit] = { // If we just drop these messages, we won't do any writes to the query. The lame duck tasks // won't shed errors or anything. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 99e30561f81d5..82836dced9df7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -120,7 +120,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + test("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { setWriterPartitions(2) setReaderPartitions(2) @@ -141,7 +141,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -162,7 +162,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) From dd37529a8dada6ed8a49b8ce50875268f6a20cba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 19 May 2018 18:51:02 +0800 Subject: [PATCH 524/593] [SPARK-24250][SQL] support accessing SQLConf inside tasks ## What changes were proposed in this pull request? Previously in #20136 we decided to forbid tasks to access `SQLConf`, because it doesn't work and always give you the default conf value. In #21190 we fixed the check and all the places that violate it. Currently the pattern of accessing configs at the executor side is: read the configs at the driver side, then access the variables holding the config values in the RDD closure, so that they will be serialized to the executor side. Something like ``` val someConf = conf.getXXX child.execute().mapPartitions { if (someConf == ...) ... ... } ``` However, this pattern is hard to apply if the config needs to be propagated via a long call stack. An example is `DataType.sameType`, and see how many changes were made in #21190 . When it comes to code generation, it's even worse. I tried it locally and we need to change a ton of files to propagate configs to code generators. This PR proposes to allow tasks to access `SQLConf`. The idea is, we can save all the SQL configs to job properties when an SQL execution is triggered. At executor side we rebuild the `SQLConf` from job properties. ## How was this patch tested? a new test suite Author: Wenchen Fan Closes #21299 from cloud-fan/config. --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 50 ++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 9 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..643e4c686f58d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,13 +27,12 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -107,7 +106,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1297,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1764,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..032525a08ccdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,37 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..d54bfbfc14f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..404d6313ab92c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadonlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From 000e25ae7950ff005d4bbe4fffed410e5947075c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 20 May 2018 16:13:42 +0800 Subject: [PATCH 525/593] Revert "[SPARK-24250][SQL] support accessing SQLConf inside tasks" This reverts commit dd37529a8dada6ed8a49b8ce50875268f6a20cba. --- .../org/apache/spark/TaskContextImpl.scala | 2 - .../spark/sql/internal/ReadOnlySQLConf.scala | 66 ------------------- .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +----- .../spark/sql/execution/SQLExecution.scala | 50 ++++---------- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 ------------------- 9 files changed, 36 insertions(+), 210 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 0791fe856ef15..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,6 +178,4 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException - // TODO: shall we publish it and define it in `TaskContext`? - private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala deleted file mode 100644 index 19f67236c8979..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.internal - -import java.util.{Map => JMap} - -import org.apache.spark.{TaskContext, TaskContextImpl} -import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} - -/** - * A readonly SQLConf that will be created by tasks running at the executor side. It reads the - * configs from the local properties which are propagated from driver to executors. - */ -class ReadOnlySQLConf(context: TaskContext) extends SQLConf { - - @transient override val settings: JMap[String, String] = { - context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] - } - - @transient override protected val reader: ConfigReader = { - new ConfigReader(new TaskContextConfigProvider(context)) - } - - override protected def setConfWithCheck(key: String, value: String): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def unsetConf(key: String): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def unsetConf(entry: ConfigEntry[_]): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def clear(): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def clone(): SQLConf = { - throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") - } - - override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { - throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") - } -} - -class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { - override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 643e4c686f58d..53a50305348fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,12 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -106,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (TaskContext.get != null) { - new ReadOnlySQLConf(TaskContext.get()) - } else { - confGetter.get()() - } - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1297,11 +1292,17 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient protected val reader = new ConfigReader(settings) + @transient private val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1764,7 +1765,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - protected def setConfWithCheck(key: String, value: String): Unit = { + private def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 032525a08ccdb..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,18 +68,16 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sc.getCallSite() + val callSite = sparkSession.sparkContext.getCallSite() - withSQLConfPropagated(sparkSession) { - sc.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sc.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - } + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { executionIdToQueryExecution.remove(executionId) @@ -92,37 +90,13 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { - val sc = sparkSession.sparkContext + def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - withSQLConfPropagated(sparkSession) { - try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) - body - } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) - } - } - } - - def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { - val sc = sparkSession.sparkContext - // Set all the specified SQL configs to local properties, so that they can be available at - // the executor side. - val allConfigs = sparkSession.sessionState.conf.getAllConfs - val originalLocalProps = allConfigs.collect { - case (key, value) if key.startsWith("spark") => - val originalValue = sc.getLocalProperty(key) - sc.setLocalProperty(key, value) - (key, originalValue) - } - try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - for ((key, value) <- originalLocalProps) { - sc.setLocalProperty(key, value) - } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d54bfbfc14f5f..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + SQLExecution.withExecutionId(sparkContext, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..ba83df0efebd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,7 +34,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -105,19 +104,22 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - SQLExecution.withSQLConfPropagated(json.sparkSession) { - JsonInferSchema.infer(rdd, parsedOptions, rowParser) - } + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = inputPaths.map(_.getPath.toString), + paths = paths, className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -163,9 +165,7 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - SQLExecution.withSQLConfPropagated(sparkSession) { - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) - } + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9e0ec9481b0de..daea6c39624d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala deleted file mode 100644 index 404d6313ab92c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.internal - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.test.SQLTestUtils - -class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { - import testImplicits._ - - protected var spark: SparkSession = null - - // Create a new [[SparkSession]] running in local-cluster mode. - override def beforeAll(): Unit = { - super.beforeAll() - spark = SparkSession.builder() - .master("local-cluster[2,1,1024]") - .appName("testing") - .getOrCreate() - } - - override def afterAll(): Unit = { - spark.stop() - spark = null - } - - test("ReadonlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => - val conf = SQLConf.get - Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") - }.collect() - assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") - } - } - - test("case-sensitive config should work for json schema inference") { - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - withTempPath { path => - val pathString = path.getCanonicalPath - spark.range(10).select('id.as("ID")).write.json(pathString) - spark.range(10).write.mode("append").json(pathString) - assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) - } - } - } -} From 8eac621229b50e15bea550a751593bba0bf8b20c Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 20 May 2018 18:15:04 -0500 Subject: [PATCH 526/593] [SPARK-23857][MESOS] remove keytab check in mesos cluster mode at first submit time ## What changes were proposed in this pull request? - Removes the check for the keytab when we are running in mesos cluster mode. - Keeps the check for client mode since in cluster mode we eventually launch the driver within the cluster in client mode. In the latter case we want to have the check done when the container starts, the keytab should be checked if it exists within the container's local filesystem. ## How was this patch tested? This was manually tested by running spark submit in mesos cluster mode. Author: Stavros Closes #20967 from skonto/fix_mesos_keytab_susbmit. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 087e9c31a9c9a..4baf032f0e9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -310,6 +310,7 @@ private[spark] class SparkSubmit extends Logging { val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isMesosClient = clusterManager == MESOS && deployMode == CLIENT if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files @@ -337,7 +338,7 @@ private[spark] class SparkSubmit extends Logging { val targetDir = Utils.createTempDir() // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) { if (args.principal != null) { if (args.keytab != null) { require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") From f32b7faf7c4b5d2ac45a2db96935f67d1b629ca2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 21 May 2018 09:47:52 +0800 Subject: [PATCH 527/593] [MINOR][PROJECT-INFRA] Check if 'original_head' variable is defined in clean_up at merge script ## What changes were proposed in this pull request? This PR proposes to check if global variable exists or not in clean_up. This can happen when it fails at: https://github.com/apache/spark/blob/7013eea11cb32b1e0038dc751c485da5c94a484b/dev/merge_spark_pr.py#L423 I found this (It was my environment problem) but the error message took me a while to debug. ## How was this patch tested? Manually tested: **Before** ``` git rev-parse --abbrev-ref HEAD fatal: Not a git repository (or any of the parent directories): .git Traceback (most recent call last): File "./dev/merge_spark_pr_jira.py", line 517, in clean_up() File "./dev/merge_spark_pr_jira.py", line 104, in clean_up print("Restoring head pointer to %s" % original_head) NameError: global name 'original_head' is not defined ``` **After** ``` git rev-parse --abbrev-ref HEAD fatal: Not a git repository (or any of the parent directories): .git Traceback (most recent call last): File "./dev/merge_spark_pr.py", line 516, in main() File "./dev/merge_spark_pr.py", line 424, in main original_head = get_current_ref() File "./dev/merge_spark_pr.py", line 412, in get_current_ref ref = run_cmd("git rev-parse --abbrev-ref HEAD").strip() File "./dev/merge_spark_pr.py", line 94, in run_cmd return subprocess.check_output(cmd.split(" ")) File "/usr/local/Cellar/python2/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/subprocess.py", line 219, in check_output raise CalledProcessError(retcode, cmd, output=output) subprocess.CalledProcessError: Command '['git', 'rev-parse', '--abbrev-ref', 'HEAD']' returned non-zero exit status 128 ``` Author: hyukjinkwon Closes #21349 from HyukjinKwon/minor-merge-script. --- dev/merge_spark_pr.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 5ea205fbed4aa..7f46a1c8f6a7c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -101,14 +101,15 @@ def continue_maybe(prompt): def clean_up(): - print("Restoring head pointer to %s" % original_head) - run_cmd("git checkout %s" % original_head) + if 'original_head' in globals(): + print("Restoring head pointer to %s" % original_head) + run_cmd("git checkout %s" % original_head) - branches = run_cmd("git branch").replace(" ", "").split("\n") + branches = run_cmd("git branch").replace(" ", "").split("\n") - for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): - print("Deleting local branch %s" % branch) - run_cmd("git branch -D %s" % branch) + for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): + print("Deleting local branch %s" % branch) + run_cmd("git branch -D %s" % branch) # merge the requested PR and return the merge hash From 6d7d45a1af078edd9e4ed027e735d6096482179c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 May 2018 15:39:35 +0800 Subject: [PATCH 528/593] [SPARK-24242][SQL] RangeExec should have correct outputOrdering and outputPartitioning ## What changes were proposed in this pull request? Logical `Range` node has been added with `outputOrdering` recently. It's used to eliminate redundant `Sort` during optimization. However, this `outputOrdering` doesn't not propagate to physical `RangeExec` node. We also add correct `outputPartitioning` to `RangeExec` node. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21291 from viirya/SPARK-24242. --- python/pyspark/sql/tests.py | 4 +-- .../execution/basicPhysicalOperators.scala | 14 ++++++++++ .../spark/sql/ConfigBehaviorSuite.scala | 4 ++- .../spark/sql/execution/PlannerSuite.scala | 27 ++++++++++++++++++- .../execution/WholeStageCodegenSuite.scala | 4 +-- .../sql/execution/debug/DebuggingSuite.scala | 7 +++-- 6 files changed, 52 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a1b6db71782bb..c7bd8f01b907f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5239,8 +5239,8 @@ def test_complex_groupby(self): expected2 = df.groupby().agg(sum(df.v)) # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..2df81d09c58e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override val output: Seq[Attribute] = range.output + override def outputOrdering: Seq[SortOrder] = range.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numElements > 0) { + if (numSlices == 1) { + SinglePartition + } else { + RangePartitioning(outputOrdering, numSlices) + } + } else { + UnknownPartitioning(0) + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 949505e449fd7..276496be3d62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id.desc) + // Range has range partitioning in its output now. To have a range shuffle, we + // need to run a repartition first. + val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a375f881c7d63..b2aba8e72c5db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -633,6 +633,31 @@ class PlannerSuite extends SharedSQLContext { requiredOrdering = Seq(orderingA, orderingB), shouldHaveSort = true) } + + test("SPARK-24242: RangeExec should have correct output ordering and partitioning") { + val df = spark.range(10) + val rangeExec = df.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + val range = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(rangeExec.head.outputOrdering == range.head.outputOrdering) + assert(rangeExec.head.outputPartitioning == + RangePartitioning(rangeExec.head.outputOrdering, df.rdd.getNumPartitions)) + + val rangeInOnePartition = spark.range(1, 10, 1, 1) + val rangeExecInOnePartition = rangeInOnePartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInOnePartition.head.outputPartitioning == SinglePartition) + + val rangeInZeroPartition = spark.range(-10, -9, -20, 1) + val rangeExecInZeroPartition = rangeInZeroPartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..b714dcd5269fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,12 +51,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = spark.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) - assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index adcaf2d76519f..8251ff159e05f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -33,14 +34,16 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } test("debugCodegenStringSeq") { - val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.length == 2) assert(res.forall{ case (subtree, code) => subtree.contains("Range") && code.contains("Object[]")}) From e480eccd9754b4900c3e2c2036d69130a262cffe Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 21 May 2018 15:42:04 +0800 Subject: [PATCH 529/593] [SPARK-24323][SQL] Fix lint-java errors ## What changes were proposed in this pull request? This PR fixes the following errors reported by `lint-java` ``` % dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java:[39] (sizes) LineLength: Line is longer than 100 characters (found 104). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:[26] (sizes) LineLength: Line is longer than 100 characters (found 110). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:[30] (sizes) LineLength: Line is longer than 100 characters (found 104). ``` ## How was this patch tested? Run `lint-java` manually. Author: Kazuaki Ishizaki Closes #21374 from kiszk/SPARK-24323. --- .../spark/sql/sources/v2/reader/InputPartition.java | 4 ++-- .../spark/sql/sources/v2/reader/InputPartitionReader.java | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index f53687e113ae0..f2038d0de3ffe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -36,8 +36,8 @@ public interface InputPartition extends Serializable { /** - * The preferred locations where the input partition reader returned by this partition can run faster, - * but Spark does not guarantee to run the input partition reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run + * faster, but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index f0d808536207a..33fa7be4c1b20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,12 +23,12 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for - * outputting data for a RDD partition. + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is + * responsible for outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input - * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition - * readers that mix in {@link SupportsScanUnsafeRow}. + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input + * partition readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { From a6e883feb3b78232ad5cf636f7f7d5e825183041 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 21 May 2018 23:14:03 +0900 Subject: [PATCH 530/593] [SPARK-23935][SQL] Adding map_entries function ## What changes were proposed in this pull request? This PR adds `map_entries` function that returns an unordered array of all entries in the given map. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionsSuite` ## CodeGen examples ### Primitive types ``` val df = Seq(Map(1 -> 5, 2 -> 6)).toDF("m") df.filter('m.isNotNull).select(map_entries('m)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ /* 044 */ ArrayData project_value_0 = null; /* 045 */ /* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements(); /* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray(); /* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray(); /* 049 */ /* 050 */ final long project_size_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 051 */ project_numElements_0, /* 052 */ 32); /* 053 */ if (project_size_0 > 2147483632) { /* 054 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0]; /* 055 */ for (int z = 0; z < project_numElements_0; z++) { /* 056 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getInt(z), project_values_0.getInt(z)}); /* 057 */ } /* 058 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0); /* 059 */ /* 060 */ } else { /* 061 */ final byte[] project_arrayBytes_0 = new byte[(int)project_size_0]; /* 062 */ UnsafeArrayData project_unsafeArrayData_0 = new UnsafeArrayData(); /* 063 */ Platform.putLong(project_arrayBytes_0, 16, project_numElements_0); /* 064 */ project_unsafeArrayData_0.pointTo(project_arrayBytes_0, 16, (int)project_size_0); /* 065 */ /* 066 */ final int project_structsOffset_0 = UnsafeArrayData.calculateHeaderPortionInBytes(project_numElements_0) + project_numElements_0 * 8; /* 067 */ UnsafeRow project_unsafeRow_0 = new UnsafeRow(2); /* 068 */ for (int z = 0; z < project_numElements_0; z++) { /* 069 */ long offset = project_structsOffset_0 + z * 24L; /* 070 */ project_unsafeArrayData_0.setLong(z, (offset << 32) + 24L); /* 071 */ project_unsafeRow_0.pointTo(project_arrayBytes_0, 16 + offset, 24); /* 072 */ project_unsafeRow_0.setInt(0, project_keys_0.getInt(z)); /* 073 */ project_unsafeRow_0.setInt(1, project_values_0.getInt(z)); /* 074 */ } /* 075 */ project_value_0 = project_unsafeArrayData_0; /* 076 */ /* 077 */ } ``` ### Non-primitive types ``` val df = Seq(Map("a" -> "foo", "b" -> null)).toDF("m") df.filter('m.isNotNull).select(map_entries('m)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ /* 044 */ ArrayData project_value_0 = null; /* 045 */ /* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements(); /* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray(); /* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray(); /* 049 */ /* 050 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0]; /* 051 */ for (int z = 0; z < project_numElements_0; z++) { /* 052 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getUTF8String(z), project_values_0.getUTF8String(z)}); /* 053 */ } /* 054 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0); ``` Author: Marek Novotny Closes #21236 from mn-mikke/feature/array-api-map_entries-to-master. --- python/pyspark/sql/functions.py | 20 +++ .../sql/catalyst/expressions/UnsafeRow.java | 2 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 153 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 23 +++ .../expressions/ExpressionEvalHelper.scala | 3 + .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 44 +++++ 9 files changed, 287 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8490081facc5a..fbc8a2d038f8f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2344,6 +2344,26 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_entries(col): + """ + Collection function: Returns an unordered array of all entries in the given map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_entries("data").alias("entries")).show() + +----------------+ + | entries| + +----------------+ + |[[1, a], [2, b]]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 867c2d5eab53d..1134a8866dc13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -419,6 +419,7 @@ object FunctionRegistry { expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapEntries]("map_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dda525294259..d382d9aace109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -764,6 +764,40 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c82db839438ed..8d763dca5243e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -154,6 +155,158 @@ case class MapValues(child: Expression) override def prettyName: String = "map_values" } +/** + * Returns an unordered array of all entries in the given map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) + } + + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 + } + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) + } + + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val structSizeAsLong = structSize + "L" + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin + } else { + valueAssignment + } + + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") + + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + } else { + getValue(values) + } + + s""" + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "map_entries" +} + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6ae1ac18c4dc4..71ff96bb722e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapEntries") { + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys/values + val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) + val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) + val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + + checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) + checkEvaluation(MapEntries(mi1), Seq.empty) + checkEvaluation(MapEntries(mi2), null) + + // Non-primitive-type keys/values + val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) + val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType)) + val ms2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) + checkEvaluation(MapEntries(ms1), Seq.empty) + checkEvaluation(MapEntries(ms2), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a22e9d4655e8c..c2a44e0d33b18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: UnsafeRow, expected: GenericInternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2a8fe583b83bc..5ab9cb3fb86a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3492,6 +3492,13 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns an unordered array of all entries in the given map. + * @group collection_funcs + * @since 2.4.0 + */ + def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d08982a138bc5..df23e07e441a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_entries") { + val dummyFilter = (c: Column) => c.isNotNull || c.isNull + + // Primitive-type elements + val idf = Seq( + Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), + Map[Int, Int](), + null + ).toDF("m") + val iExpected = Seq( + Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) + checkAnswer( + spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + checkAnswer( + spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + + // Non-primitive-type elements + val sdf = Seq( + Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"), + Map[String, String]("a" -> null, "b" -> null), + Map[String, String](), + null + ).toDF("m") + val sExpected = Seq( + Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), From 03e90f65bfdad376400a4ae4df31a82c05ed4d4b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 May 2018 00:19:18 +0800 Subject: [PATCH 531/593] [SPARK-24250][SQL] support accessing SQLConf inside tasks re-submit https://github.com/apache/spark/pull/21299 which broke build. A few new commits are added to fix the SQLConf problem in `JsonSchemaInference.infer`, and prevent us to access `SQLConf` in DAGScheduler event loop thread. ## What changes were proposed in this pull request? Previously in #20136 we decided to forbid tasks to access `SQLConf`, because it doesn't work and always give you the default conf value. In #21190 we fixed the check and all the places that violate it. Currently the pattern of accessing configs at the executor side is: read the configs at the driver side, then access the variables holding the config values in the RDD closure, so that they will be serialized to the executor side. Something like ``` val someConf = conf.getXXX child.execute().mapPartitions { if (someConf == ...) ... ... } ``` However, this pattern is hard to apply if the config needs to be propagated via a long call stack. An example is `DataType.sameType`, and see how many changes were made in #21190 . When it comes to code generation, it's even worse. I tried it locally and we need to change a ton of files to propagate configs to code generators. This PR proposes to allow tasks to access `SQLConf`. The idea is, we can save all the SQL configs to job properties when an SQL execution is triggered. At executor side we rebuild the `SQLConf` from job properties. ## How was this patch tested? a new test suite Author: Wenchen Fan Closes #21376 from cloud-fan/config. --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../org/apache/spark/util/EventLoop.scala | 3 +- .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 33 ++++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 54 +++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../datasources/json/JsonInferSchema.scala | 15 +++-- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 12 files changed, 239 insertions(+), 43 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..5f2d16d03165f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -206,7 +206,7 @@ class DAGScheduler( private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. + private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..a2fb3c64844b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -95,7 +95,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. * * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -107,7 +109,22 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + if (Utils.isTesting && SparkContext.getActive.isDefined) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we prevent it from happening. + val schedulerEventLoopThread = + SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread + if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } + } + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1309,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1776,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,41 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2df81d09c58e7..9434ceb7cd16c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -643,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e7eed95a560a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -45,8 +45,9 @@ private[sql] object JsonInferSchema { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -66,9 +67,13 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } - } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + }.reduceOption(typeMerger).toIterator + } + + // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because + // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have + // active SparkSession and `SQLConf.get` may point to the wrong configs. + val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) canonicalizeType(rootType) match { case Some(st: StructType) => st diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..3dd0712e02448 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadOnlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From a33dcf4a0bbe20dce6f1e1e6c2e1c3828291fb3d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 21 May 2018 12:58:05 -0700 Subject: [PATCH 532/593] [SPARK-24234][SS] Reader for continuous processing shuffle ## What changes were proposed in this pull request? Read RDD for continuous processing shuffle, as well as the initial RPC-based row receiver. https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii ## How was this patch tested? new unit tests Author: Jose Torres Closes #21337 from jose-torres/readerRddMaster. --- .../shuffle/ContinuousShuffleReadRDD.scala | 61 ++++++ .../shuffle/ContinuousShuffleReader.scala | 32 +++ .../shuffle/UnsafeRowReceiver.scala | 75 +++++++ .../shuffle/ContinuousShuffleReadSuite.scala | 184 ++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..270b1a5c28dee --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. + */ +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..42631c90ebc55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..b8adbb743c6c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. + */ +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..b25e75b3b37a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverEpochMarker() + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. + send( + part.endpoint, + ReceiverRow(unsafeRow(part.index)), + ReceiverEpochMarker() + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRowThread = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } +} From ffaefe755e20cb94e27f07b233615a4bbb476679 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 21 May 2018 13:05:17 -0700 Subject: [PATCH 533/593] [SPARK-7132][ML] Add fit with validation set to spark.ml GBT ## What changes were proposed in this pull request? Add fit with validation set to spark.ml GBT ## How was this patch tested? Will add later. Author: WeichenXu Closes #21129 from WeichenXu123/gbt_fit_validation. --- .../ml/classification/GBTClassifier.scala | 38 ++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 5 +- .../spark/ml/param/shared/sharedParams.scala | 17 +++++++ .../spark/ml/regression/GBTRegressor.scala | 31 ++++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 41 +++++++++++----- .../classification/GBTClassifierSuite.scala | 46 ++++++++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 48 ++++++++++++++++++- project/MimaExcludes.scala | 13 ++++- 8 files changed, 213 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3fb6d1e4e4f3e..337133a2e2326 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. - val oldDataset: RDD[LabeledPoint] = + val convert2LabeledPoint = (dataset: Dataset[_]) => { dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + @@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") ( s" GBTClassifier currently only supports binary classification.") LabeledPoint(label, features) } - val numFeatures = oldDataset.first().features.size + } + + val (trainDataset, validationDataset) = if (withValidation) { + ( + convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), + convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (convert2LabeledPoint(dataset), null) + } + + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) + } + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b9c3170cc3c28..7e08675f834da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), isValid = "(value: String) => " + - "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"), + ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " + + "each row is for training or for validation. False indicates training; true indicates " + + "validation.") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 282ea6ebcbf7f..5928a0749f738 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params { /** @group getParam */ final def getDistanceMeasure: String = $(distanceMeasure) } + +/** + * Trait for shared param validationIndicatorCol. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasValidationIndicatorCol extends Params { + + /** + * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.. + * @group param + */ + final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.") + + /** @group getParam */ + final def getValidationIndicatorCol: String = $(validationIndicatorCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index d7e054bf55ef6..eb8b3c001436a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val numFeatures = oldDataset.first().features.size + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + + val (trainDataset, validationDataset) = if (withValidation) { + ( + extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), + extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (extractLabeledPoints(dataset), null) + } + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index ec8868bb42cbb..00157fe63af41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.util.Try +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. - * (default = 1e-5) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize + with HasValidationIndicatorCol { + + /** + * Threshold for stopping early when fit with validation is used. + * (This parameter is ignored when fit without validation is used.) + * The decision to stop early is decided based on this logic: + * If the current loss on the validation set is greater than 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is less than or equal to 0.01, + * the diff of validation error is compared to absolute tolerance which is + * validationTol * 0.01. * @group param + * @see validationIndicatorCol */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 + @Since("2.4.0") + final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", + "Threshold for stopping early when fit with validation is used." + + "If the error rate on the validation input changes by less than the validationTol," + + "then learning will stop early (before `maxIter`)." + + "This parameter is ignored when fit without validation is used.", + ParamValidators.gtEq(0.0) + ) + + /** @group getParam */ + @Since("2.4.0") + final def getValidationTol: Double = $(validationTol) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") @@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) // NOTE: The old API does not support "seed" so we ignore it. - new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol) } /** Get old Gradient Boosting Loss type */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e20de196d65ca..e6d2a8e2b900e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -392,6 +393,51 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(evalArr(2) ~== lossErr3 relTol 1E-3) } + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTClassifier.supportedLossTypes) { + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val (errorWithoutValidation, errorWithValidation) = { + val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), + GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) + } + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Classification) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 773f6d2c542fe..b145c7a3dc952 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -231,7 +232,52 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } - ///////////////////////////////////////////////////////////////////////////// + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + modelWithoutValidation.trees, modelWithoutValidation.treeWeights, + modelWithoutValidation.getOldLossType) + val errorWithValidation = GradientBoostedTrees.computeError(validationData, + modelWithValidation.trees, modelWithValidation.treeWeights, + modelWithValidation.getOldLossType) + + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Regression) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7d0e88ee20c3f..6bae4d147d4ac 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -73,7 +73,18 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), + + // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") ) // Exclude rules for 2.3.x From b550b2a1a159941c7327973182f16004a6bf179d Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 21 May 2018 14:21:05 -0700 Subject: [PATCH 534/593] [SPARK-24325] Tests for Hadoop's LinesReader ## What changes were proposed in this pull request? The tests cover basic functionality of [Hadoop LinesReader](https://github.com/apache/spark/blob/8d79113b812a91073d2c24a3a9ad94cc3b90b24a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala#L42). In particular, the added tests check: - A split slices a line or delimiter - A split slices two consecutive lines and cover a delimiter between the lines - Two splits slice a line and there are no duplicates - Internal buffer size (`io.file.buffer.size`) is less than line length - Constrain of maximum line length - `mapreduce.input.linerecordreader.line.maxlength` Author: Maxim Gekk Closes #21377 from MaxGekk/line-reader-tests. --- .../HadoopFileLinesReaderSuite.scala | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala new file mode 100644 index 0000000000000..a39a25be262a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFileLinesReaderSuite extends SharedSQLContext { + def getLines( + path: File, + text: String, + ranges: Seq[(Long, Long)], + delimiter: Option[String] = None, + conf: Option[Configuration] = None): Seq[String] = { + val delimOpt = delimiter.map(_.getBytes(StandardCharsets.UTF_8)) + Files.write(path.toPath, text.getBytes(StandardCharsets.UTF_8)) + + val lines = ranges.map { case (start, length) => + val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) + val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) + + reader.map(_.toString) + }.flatten + + lines + } + + test("A split ends at the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 1), (1, 3))) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 2), (2, 2))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the end of the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 3), (3, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split covers two lines") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 4), (4, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 1), (1, 4)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split slices the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 2), (2, 3)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("The first split covers the first line and the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 4), (4, 1)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the first line") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((0, 1)), Some(",")) + assert(lines == Seq("abc")) + } + } + + test("The split cuts both lines") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((2, 2)), Some(",")) + assert(lines == Seq("def")) + } + } + + test("io.file.buffer.size is less than line length") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("io.file.buffer.size", "2") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } + } + + test("line cannot be longer than line.maxlength") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } + } + + test("default delimiter is 0xd or 0xa or 0xd0xa") { + withTempPath { path => + val lines = getLines(path, text = "1\r2\n3\r\n4", ranges = Seq((0, 3), (3, 5))) + assert(lines == Seq("1", "2", "3", "4")) + } + } +} From 32447079e9d0fa9f7e180b94ecac19091b6af1ab Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 May 2018 16:26:39 -0700 Subject: [PATCH 535/593] [SPARK-24309][CORE] AsyncEventQueue should stop on interrupt. EventListeners can interrupt the event queue thread. In particular, when the EventLoggingListener writes to hdfs, hdfs can interrupt the thread. When there is an interrupt, the queue should be removed and stop accepting any more events. Before this change, the queue would continue to take more events (till it was full), and then would not stop when the application was complete because the PoisonPill couldn't be added. Added a unit test which failed before this change. Author: Imran Rashid Closes #21356 from squito/SPARK-24309. --- .../spark/scheduler/AsyncEventQueue.scala | 41 ++++++++------ .../spark/scheduler/LiveListenerBus.scala | 2 +- .../org/apache/spark/util/ListenerBus.scala | 18 +++++++ .../spark/scheduler/SparkListenerSuite.scala | 54 +++++++++++++++++++ 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a90..e2b6df4600590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08ad..d135190d1e919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f23..d4474a90b26f1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index fa47a52bbbc47..6ffd1e84f7adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -489,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -547,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. + */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want From 84d31aa5d453620d462f1fdd90206c676a8395cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 21 May 2018 18:11:05 -0700 Subject: [PATCH 536/593] [SPARK-24209][SHS] Automatic retrieve proxyBase from Knox headers ## What changes were proposed in this pull request? The PR retrieves the proxyBase automatically from the header `X-Forwarded-Context` (if available). This is the header used by Knox to inform the proxied service about the base path. This provides 0-configuration support for Knox gateway (instead of having to properly set `spark.ui.proxyBase`) and it allows to access directly SHS when it is proxied by Knox. In the previous scenario, indeed, after setting `spark.ui.proxyBase`, direct access to SHS was not working fine (due to bad link generated). ## How was this patch tested? added UT + manual tests Author: Marco Gaido Closes #21268 from mgaido91/SPARK-24209. --- .../spark/deploy/history/HistoryPage.scala | 17 +-- .../spark/deploy/history/HistoryServer.scala | 2 +- .../deploy/master/ui/ApplicationPage.scala | 4 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../spark/deploy/worker/ui/LogPage.scala | 2 +- .../spark/deploy/worker/ui/WorkerPage.scala | 2 +- .../scala/org/apache/spark/ui/UIUtils.scala | 109 ++++++++++-------- .../apache/spark/ui/env/EnvironmentPage.scala | 2 +- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../apache/spark/ui/exec/ExecutorsTab.scala | 6 +- .../apache/spark/ui/jobs/AllJobsPage.scala | 4 +- .../apache/spark/ui/jobs/AllStagesPage.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 5 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 4 +- .../org/apache/spark/ui/jobs/PoolTable.scala | 9 +- .../org/apache/spark/ui/jobs/StagePage.scala | 8 +- .../org/apache/spark/ui/jobs/StageTable.scala | 12 +- .../org/apache/spark/ui/storage/RDDPage.scala | 7 +- .../apache/spark/ui/storage/StoragePage.scala | 20 +++- .../deploy/history/HistoryServerSuite.scala | 24 ++++ .../spark/ui/storage/StoragePageSuite.scala | 7 +- .../spark/deploy/mesos/ui/DriverPage.scala | 6 +- .../deploy/mesos/ui/MesosClusterPage.scala | 2 +- .../sql/execution/ui/AllExecutionsPage.scala | 38 +++--- .../sql/execution/ui/ExecutionPage.scala | 30 ++--- .../thriftserver/ui/ThriftServerPage.scala | 17 +-- .../ui/ThriftServerSessionPage.scala | 9 +- .../apache/spark/streaming/ui/BatchPage.scala | 21 +++- .../spark/streaming/ui/StreamingPage.scala | 12 +- 29 files changed, 232 insertions(+), 155 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6fc12d721e6f1..32667ddf5c7ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - ++ - + ++ +
    - UIUtils.basicSparkPage(content, "History Server", true) + UIUtils.basicSparkPage(request, content, "History Server", true) } - private def makePageLink(showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) + private def makePageLink(request: HttpServletRequest, showIncomplete: Boolean): String = { + UIUtils.prependBaseUri(request, "/?" + "showIncomplete=" + showIncomplete) } private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 611fa563a7cd9..a9a4d5a4ec6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -87,7 +87,7 @@ class HistoryServer( if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { val msg =
    Application {appId} not found.
    res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n => res.getWriter().write(n.toString) } return diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f699c75085fe1..fad4e46dc035d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
    No running application with ID {appId}
    - return UIUtils.basicSparkPage(msg, "Not Found") + return UIUtils.basicSparkPage(request, msg, "Not Found") } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") @@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } ; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name) } private def executorRow(executor: ExecutorDesc): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c629937606b51..b8afe203fbfa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } private def workerRow(worker: WorkerInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 2f5a5642d3cab..4fca9342c0378 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with - UIUtils.basicSparkPage(content, logType + " log page for " + pageName) + UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8b98ae56fc108..aa4e28d213e2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 02cf19e00ecde..5d015b0531ef6 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} +import javax.servlet.http.HttpServletRequest import scala.util.control.NonFatal import scala.xml._ @@ -148,60 +149,71 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - def uiRoot: String = { + def uiRoot(request: HttpServletRequest): String = { + // Knox uses X-Forwarded-Context to notify the application the base path + val knoxBasePath = Option(request.getHeader("X-Forwarded-Context")) // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. sys.props.get("spark.ui.proxyBase") .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .orElse(knoxBasePath) .getOrElse("") } - def prependBaseUri(basePath: String = "", resource: String = ""): String = { - uiRoot + basePath + resource + def prependBaseUri( + request: HttpServletRequest, + basePath: String = "", + resource: String = ""): String = { + uiRoot(request) + basePath + resource } - def commonHeaderNodes: Seq[Node] = { + def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = { - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + } - def vizHeaderNodes: Seq[Node] = { - - - - - + def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + + + + } - def dataTablesHeaderNodes: Seq[Node] = { + def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> - - - - - - - + href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/> + + + + + + } /** Returns a spark page with correctly formatted headers */ def headerSparkPage( + request: HttpServletRequest, title: String, content: => Seq[Node], activeTab: SparkUITab, @@ -214,25 +226,26 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) - {commonHeaderNodes} - {if (showVisualization) vizHeaderNodes else Seq.empty} - {if (useDataTables) dataTablesHeaderNodes else Seq.empty} - + {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} + {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} + {appName} - {title}
    }.getOrElse(Text("Error fetching thread dump")) - UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + UIUtils.headerSparkPage(request, s"Thread dump for executor $executorId", content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 843486f4a70d2..d5a60f52cbb0f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -49,12 +49,12 @@ private[ui] class ExecutorsPage(
    {
    ++ - ++ - ++ + ++ + ++ }
    - UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) + UIUtils.headerSparkPage(request, "Executors", content, parent, useDataTables = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2b0f4acbac72a..f651fe97c2cd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -248,7 +248,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs, tableHeaderId, jobTag, - UIUtils.prependBaseUri(parent.basePath), + UIUtils.prependBaseUri(request, parent.basePath), "jobs", // subPath parameterOtherTable, killEnabled, @@ -407,7 +407,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + " Click on a job to see information about the stages of tasks inside it." - UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + UIUtils.headerSparkPage(request, "Spark Jobs", content, parent, helpText = Some(helpText)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 4658aa1cea3f1..f672ce0ec6a68 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -66,7 +66,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { ++
    - {poolTable.toNodeSeq} + {poolTable.toNodeSeq(request)}
    } else { Seq.empty[Node] @@ -74,7 +74,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val content = summary ++ poolsDescription ++ tables.flatten.flatten - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + UIUtils.headerSparkPage(request, "Stages for All Jobs", content, parent) } private def summaryAndTableForStatus( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 46f2a76cc651b..55444a2c0c9ab 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -195,7 +195,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP

    No information to display for job {jobId}

    return UIUtils.headerSparkPage( - s"Details for Job $jobId", content, parent) + request, s"Details for Job $jobId", content, parent) } val isComplete = jobData.status != JobExecutionStatus.RUNNING val stages = jobData.stageIds.map { stageId => @@ -413,6 +413,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP {failedStagesTable.toNodeSeq} } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + UIUtils.headerSparkPage( + request, s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index a3e1f13782e30..22a40101e33df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -49,7 +49,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { "stages/pool", parent.isFairScheduler, parent.killEnabled, false) val poolTable = new PoolTable(Map(pool -> uiPool), parent) - var content =

    Summary

    ++ poolTable.toNodeSeq + var content =

    Summary

    ++ poolTable.toNodeSeq(request) if (activeStages.nonEmpty) { content ++= } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + UIUtils.headerSparkPage(request, "Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 5dfce858dec07..96b5f72393070 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder +import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -28,7 +29,7 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = { @@ -39,15 +40,15 @@ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab - {pools.map { case (s, p) => poolRow(s, p) }} + {pools.map { case (s, p) => poolRow(request, s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + private def poolRow(request: HttpServletRequest, s: Schedulable, p: PoolData): Seq[Node] = { val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) + .format(UIUtils.prependBaseUri(request, parent.basePath), URLEncoder.encode(p.name, "UTF-8"))
    {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ac83de10f9237..2575914121c39 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,7 +112,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) @@ -125,7 +125,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( stageData, - UIUtils.prependBaseUri(parent.basePath) + + UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, @@ -498,7 +498,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {taskTableHTML ++ jsForScrollingDownToTaskTable}
    - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) } def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 18a4926f2f6c0..b8b20db1fa407 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -92,7 +92,8 @@ private[ui] class StageTableBase( stageSortColumn, stageSortDesc, isFailedStage, - parameterOtherTable + parameterOtherTable, + request ).table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -147,7 +148,8 @@ private[ui] class StagePagedTable( sortColumn: String, desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] { + parameterOtherTable: Iterable[String], + request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" @@ -161,7 +163,7 @@ private[ui] class StagePagedTable( override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + + val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( @@ -288,7 +290,7 @@ private[ui] class StagePagedTable( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(request, basePath), data.schedulingPool)}> {data.schedulingPool} @@ -346,7 +348,7 @@ private[ui] class StagePagedTable( } private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { - val basePathUri = UIUtils.prependBaseUri(basePath) + val basePathUri = UIUtils.prependBaseUri(request, basePath) val killLink = if (killEnabled) { val confirm = diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2674b9291203a..238cd31433660 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -53,7 +53,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } catch { case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) + return UIUtils.headerSparkPage(request, "RDD Not Found", Seq.empty[Node], parent) } // Worker table @@ -72,7 +72,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } val blockTableHTML = try { val _blockTable = new BlockPagedTable( - UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, @@ -145,7 +145,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web {blockTableHTML ++ jsForScrollingDownToBlockTable} ; - UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) + UIUtils.headerSparkPage( + request, "RDD Storage Info for " + rddStorageInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 68d946574a37b..3eb546e336e99 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,11 +31,14 @@ import org.apache.spark.util.Utils private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) - UIUtils.headerSparkPage("Storage", content, parent) + val content = rddTable(request, store.rddList()) ++ + receiverBlockTables(store.streamBlocksList()) + UIUtils.headerSparkPage(request, "Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { + private[storage] def rddTable( + request: HttpServletRequest, + rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. Nil @@ -49,7 +52,11 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends
    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + {UIUtils.listingTable( + rddHeader, + rddRow(request, _: v1.RDDStorageInfo), + rdds, + id = Some("storage-by-rdd-table"))}
    } @@ -66,12 +73,13 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { + private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} - + {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a871b1c717837..11b29121739a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito._ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} @@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } + test("automatically retrieve uiRoot from request through Knox") { + assert(sys.props.get("spark.ui.proxyBase").isEmpty, + "spark.ui.proxyBase is defined but it should not for this UT") + assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty, + "APPLICATION_WEB_PROXY_BASE is defined but it should not for this UT") + val page = new HistoryPage(server) + val requestThroughKnox = mock[HttpServletRequest] + val knoxBaseUrl = "/gateway/default/sparkhistoryui" + when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl) + val responseThroughKnox = page.render(requestThroughKnox) + + val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString) + val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/")) + all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl) + + val directRequest = mock[HttpServletRequest] + val directResponse = page.render(directRequest) + + val directUrls = directResponse \\ "@href" map (_.toString) + val directSiteRelativeLinks = directUrls filter (_.startsWith("/")) + all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) + } + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index a71521c91d2f2..cdc7f541b9552 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.storage +import javax.servlet.http.HttpServletRequest + import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") val storagePage = new StoragePage(storageTab, null) + val request = mock(classOf[HttpServletRequest]) test("rddTable") { val rdd1 = new RDDStorageInfo(1, @@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite { None, None) - val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3)) val headers = Seq( "ID", @@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite { } test("empty rddTable") { - assert(storagePage.rddTable(Seq.empty).isEmpty) + assert(storagePage.rddTable(request, Seq.empty).isEmpty) } test("streamBlockStorageLevelDescriptionAndSize") { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 022191d0070fd..91f64141e5318 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")

    Cannot find driver {driverId}

    - return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } val driverState = state.get val driverHeaders = Seq("Driver property", "Value") @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    @@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
    ; - UIUtils.basicSparkPage(content, s"Details for Job $driverId") + UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 88a6614d51384..c53285331ea68 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {retryTable} ; - UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster") } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 582528777f90e..bf46bc4cf904d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L _content ++= new RunningExecutionTable( parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq + running.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq + completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (failed.nonEmpty) { _content ++= new FailedExecutionTable( parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq + failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } _content } @@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } - UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } } @@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable( protected def header: Seq[String] - protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + protected def row( + request: HttpServletRequest, + currentTime: Long, + executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - submissionTime @@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable( def jobLinks(status: JobExecutionStatus): Seq[Node] = { executionUIData.jobs.flatMap { case (jobId, jobStatus) => if (jobStatus == status) { - [{jobId.toString}] + [{jobId.toString}] } else { None } @@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(executionUIData)} + {descriptionCell(request, executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable( } - private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell( + request: HttpServletRequest, + execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { +details @@ -192,27 +197,28 @@ private[ui] abstract class ExecutionTable( } val desc = if (execution.description != null && execution.description.nonEmpty) { - {execution.description} + {execution.description} } else { - {execution.executionId} + {execution.executionId} }
    {desc} {details}
    } - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = {

    {tableName}

    {UIUtils.listingTable[SQLExecutionUIData]( - header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))}
    } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) - private def executionURL(executionID: Long): String = - s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" + private def executionURL(request: HttpServletRequest, executionID: Long): String = + s"${UIUtils.prependBaseUri( + request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e0554f0c4d337..282f7b4bb5a58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -49,7 +49,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
  • {label} {jobs.toSeq.sorted.map { jobId => - {jobId.toString}  + {jobId.toString}  }}
  • } else { @@ -77,27 +77,31 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, graph) ++ + planVisualization(request, metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
    No information to display for query {executionId}
    } - UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + UIUtils.headerSparkPage( + request, s"Details for Query $executionId", content, parent, Some(5000)) } - private def planVisualizationResources: Seq[Node] = { + private def planVisualizationResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - - - + + + + + // scalastyle:on } - private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization( + request: HttpServletRequest, + metrics: Map[Long, String], + graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
    {node.desc}
    @@ -112,13 +116,13 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    {graph.allNodes.size.toString}
    {metadata} - {planVisualizationResources} + {planVisualizationResources(request)} } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index f517bffccdf31..0950b30126773 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,10 +47,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {listener.getOnlineSessionNum} session(s) are online, running {listener.getTotalRunning} SQL statement(s) ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + generateSessionStatsTable(request) ++ + generateSQLStatsTable(request) } - UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -67,7 +67,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", @@ -76,7 +76,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } @@ -138,7 +139,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { + private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { @@ -146,8 +147,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + val sessionLink = "%s/%s/session?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 5cd2fdf6437c2..c884aa0ecbdf8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -56,9 +56,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Session created at {formatDate(sessionStat.startTimestamp)}, Total run {sessionStat.totalExecution} SQL ++ - generateSQLStatsTable(sessionStat.sessionId) + generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -75,7 +75,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { val executionList = listener.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size @@ -86,7 +86,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 6748dd4ec48e3..ca9da6139649a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,6 +47,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -54,7 +55,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { isFirstRow: Boolean, jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { if (jobIdWithData.jobData.isDefined) { - generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(request, outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, @@ -89,6 +90,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for the first row of an output op. */ private def generateNormalJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -106,7 +108,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") - val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + val detailUrl = s"${SparkUIUtils.prependBaseUri( + request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". @@ -196,6 +199,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = @@ -212,6 +216,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { val firstRow = generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -221,6 +226,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -278,7 +284,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate the job table for the batch. */ - private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + private def generateJobTable( + request: HttpServletRequest, + batchUIData: BatchUIData): Seq[Node] = { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId @@ -301,7 +309,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { outputOpWithJobs.map { case (outputOpData, sparkJobs) => - generateOutputOpIdRow(outputOpData, sparkJobs) + generateOutputOpIdRow(request, outputOpData, sparkJobs) } } @@ -364,9 +372,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
    - val content = summary ++ generateJobTable(batchUIData) + val content = summary ++ generateJobTable(request, batchUIData) - SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + SparkUIUtils.headerSparkPage( + request, s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 3a176f64cdd60..4ce661bc1144e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -148,7 +148,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val resources = generateLoadResources() + val resources = generateLoadResources(request) val basicInfo = generateBasicInfo() val content = resources ++ basicInfo ++ @@ -156,17 +156,17 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) } /** * Generate html that will load css/js files for StreamingPage */ - private def generateLoadResources(): Seq[Node] = { + private def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - + + + // scalastyle:on } From 952e4d1c830c4eb3dfd522be3d292dd02d8c9065 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 22 May 2018 19:12:30 +0800 Subject: [PATCH 537/593] [SPARK-24321][SQL] Extract common code from Divide/Remainder to a base trait ## What changes were proposed in this pull request? Extract common code from `Divide`/`Remainder` to a new base trait, `DivModLike`. Further refactoring to make `Pmod` work with `DivModLike` is to be done as a separate task. ## How was this patch tested? Existing tests in `ArithmeticExpressionSuite` covers the functionality. Author: Kris Mok Closes #21367 from rednaxelafx/catalyst-divmod. --- .../sql/catalyst/expressions/arithmetic.scala | 145 ++++++------------ 1 file changed, 51 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..efd4e992c8eec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -220,30 +220,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +234,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. */ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" @@ -283,7 +267,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { ev.copy(code = s""" @@ -297,13 +281,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +322,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( From 82fb5bfa770b0325d4f377dd38d89869007c6111 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 22 May 2018 21:02:17 +0800 Subject: [PATCH 538/593] [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ## What changes were proposed in this pull request? The ultimate goal is for listeners to onTaskEnd to receive metrics when a task is killed intentionally, since the data is currently just thrown away. This is already done for ExceptionFailure, so this just copies the same approach. ## How was this patch tested? Updated existing tests. This is a rework of https://github.com/apache/spark/pull/17422, all credits should go to noodle-fb Author: Xianjin YE Author: Charles Lewis Closes #21165 from advancedxy/SPARK-20087. --- .../org/apache/spark/TaskEndReason.scala | 8 ++- .../org/apache/spark/executor/Executor.scala | 55 ++++++++++++------- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/TaskSetManager.scala | 8 ++- .../org/apache/spark/util/JsonProtocol.scala | 9 ++- .../spark/scheduler/DAGSchedulerSuite.scala | 18 ++++-- project/MimaExcludes.scala | 5 ++ 7 files changed, 78 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a76283e33fa65..33901bc8380e9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case class TaskKilled(reason: String) extends TaskFailedReason { +case class TaskKilled( + reason: String, + accumUpdates: Seq[AccumulableInfo] = Seq.empty, + private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil) + extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false + } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c325222b764b8..b1856ff0f3247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -287,6 +287,28 @@ private[spark] class Executor( notifyAll() } + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + override def run(): Unit = { threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) @@ -300,7 +322,7 @@ private[spark] class Executor( val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 + var taskStartTime: Long = 0 var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() @@ -336,7 +358,7 @@ private[spark] class Executor( } // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() + taskStartTime = System.currentTimeMillis() taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L @@ -396,11 +418,11 @@ private[spark] class Executor( // Deserialization happens in two parts: first, we deserialize a Task object, which // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. task.metrics.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) + (taskStartTime - deserializeStartTime) + task.executorDeserializeTime) task.metrics.setExecutorDeserializeCpuTime( (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting - task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime) task.metrics.setExecutorCpuTime( (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -482,16 +504,19 @@ private[spark] class Executor( } catch { case t: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -524,17 +549,7 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of premption, // instead of an app issue). if (!ShutdownHookManager.inShutdown()) { - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } - - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5f2d16d03165f..ea7bfd7d7a68d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1210,7 +1210,7 @@ class DAGScheduler( case _ => updateAccumulators(event) } - case _: ExceptionFailure => updateAccumulators(event) + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) case _ => } postTaskEnd(event) @@ -1414,13 +1414,13 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case exceptionFailure: ExceptionFailure => + case _: ExceptionFailure | _: TaskKilled => // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => + case _: ExecutorLostFailure | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 195fc8025e4b5..a18c66596852a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -851,13 +851,19 @@ private[spark] class TaskSetManager( } ef.exception + case tk: TaskKilled => + // TaskKilled might have accumulator updates + accumUpdates = tk.accums + logWarning(failureReason) + None + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None - case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others + case e: TaskFailedReason => // TaskResultLost and others logWarning(failureReason) None } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40383fe05026b..50c6461373dee 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -407,7 +407,9 @@ private[spark] object JsonProtocol { ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => - ("Kill Reason" -> taskKilled.reason) + val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) + ("Kill Reason" -> taskKilled.reason) ~ + ("Accumulator Updates" -> accumUpdates) case _ => emptyJson } ("Reason" -> reason) ~ json @@ -917,7 +919,10 @@ private[spark] object JsonProtocol { case `taskKilled` => val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") - TaskKilled(killReason) + val accumUpdates = jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(Seq[AccumulableInfo]()) + TaskKilled(killReason, accumUpdates) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b6ec37625eec..2987170bf5026 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("accumulators are updated on exception failures") { + test("accumulators are updated on exception failures and task killed") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") val acc3 = AccumulatorSuite.createLongAccum("agriculteur") @@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accUpdate3 = new LongAccumulator accUpdate3.metadata = acc3.metadata accUpdate3.setValue(18) - val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3) - val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo) + + val accumUpdates1 = Seq(accUpdate1, accUpdate2) + val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo) val exceptionFailure = new ExceptionFailure( new SparkException("fondue?"), - accumInfo).copy(accums = accumUpdates) + accumInfo1).copy(accums = accumUpdates1) submit(new MyRDD(sc, 1, Nil), Array(0)) runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(AccumulatorContext.get(acc1.id).get.value === 15L) assert(AccumulatorContext.get(acc2.id).get.value === 13L) + + val accumUpdates2 = Seq(accUpdate3) + val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo) + + val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result")) + assert(AccumulatorContext.get(acc3.id).get.value === 18L) } @@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accumUpdates = reason match { case Success => task.metrics.accumulators() case ef: ExceptionFailure => ef.accums + case tk: TaskKilled => tk.accums case _ => Seq.empty } CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6bae4d147d4ac..4f6d5ff898681 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), From a4470bc78ca5f5a090b6831a7cdca88274eb9afc Mon Sep 17 00:00:00 2001 From: Jake Charland Date: Tue, 22 May 2018 08:06:15 -0500 Subject: [PATCH 539/593] [SPARK-21673] Use the correct sandbox environment variable set by Mesos ## What changes were proposed in this pull request? This change changes spark behavior to use the correct environment variable set by Mesos in the container on startup. Author: Jake Charland Closes #18894 from jakecharland/MesosSandbox. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 8 ++++---- docs/configuration.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 13adaa921dc23..f9191a59c1655 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -810,15 +810,15 @@ private[spark] object Utils extends Logging { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { conf.getenv("SPARK_LOCAL_DIRS").split(",") - } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) { // Mesos already creates a directory per Mesos task. Spark should use that directory // instead so all temporary files are automatically cleaned up when the Mesos task ends. // Note that we don't want this if the shuffle service is enabled because we want to // continue to serve shuffle files after the executors that wrote them have already exited. - Array(conf.getenv("MESOS_DIRECTORY")) + Array(conf.getenv("MESOS_SANDBOX")) } else { - if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { - logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) { + logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " + "spark.shuffle.service.enabled is enabled.") } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user diff --git a/docs/configuration.md b/docs/configuration.md index 8a1aacef85760..fd2670cba2125 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -208,7 +208,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. From d3d18073152cab4408464d1417ec644d939cfdf7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 22 May 2018 21:08:49 +0800 Subject: [PATCH 540/593] [SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types ## What changes were proposed in this pull request? The interpreted evaluation of several collection operations works only for simple datatypes. For complex data types, for instance, `array_contains` it returns always `false`. The list of the affected functions is `array_contains`, `array_position`, `element_at` and `GetMapValue`. The PR fixes the behavior for all the datatypes. ## How was this patch tested? added UT Author: Marco Gaido Closes #21361 from mgaido91/SPARK-24313. --- .../expressions/collectionOperations.scala | 41 ++++++++++++---- .../expressions/complexTypeExtractors.scala | 19 +++++-- .../CollectionExpressionsSuite.scala | 49 ++++++++++++++++++- .../optimizer/complexTypesSuite.scala | 13 +++++ .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++ 5 files changed, 113 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8d763dca5243e..7da4c3cc6b9fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -657,6 +657,9 @@ case class ArrayContains(left: Expression, right: Expression) override def dataType: DataType = BooleanType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq.empty case _ => left.dataType match { @@ -673,7 +676,7 @@ case class ArrayContains(left: Expression, right: Expression) TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -686,7 +689,7 @@ case class ArrayContains(left: Expression, right: Expression) arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == null) { hasNull = true - } else if (v == value) { + } else if (ordering.equiv(v, value)) { return true } ) @@ -735,11 +738,7 @@ case class ArraysOverlap(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - if (RowOrdering.isOrderable(elementType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") - } + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") case failure => failure } @@ -1391,13 +1390,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { + if (v != null && ordering.equiv(v, value)) { return (i + 1).toLong } ) @@ -1446,6 +1456,9 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + override def dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType @@ -1460,6 +1473,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti ) } + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr( + left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + } + } + override def nullable: Boolean = true override def nullSafeEval(value: Any, ordinal: Any): Any = { @@ -1484,7 +1507,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3fba52d745453..99671d5b863c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 71ff96bb722e2..3fc0b08c56e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -157,6 +157,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + // binary + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + // complex data types + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) } test("ArraysOverlap") { @@ -372,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayPosition(aa0, aae), 1L) + checkEvaluation(ArrayPosition(aa1, aae), 0L) } test("elementAt") { @@ -409,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -420,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test binary type as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } test("Concat") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d495581..5452e72b38647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support binary type as map keys in GetMapValue") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee7504..1cc8cb3874c9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } } From fc743f7b30902bad1da36131087bb922c17a048e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 22 May 2018 08:20:59 -0500 Subject: [PATCH 541/593] [SPARK-20120][SQL][FOLLOW-UP] Better way to support spark-sql silent mode. ## What changes were proposed in this pull request? `spark-sql` silent mode will broken if`SPARK_HOME/jars` missing `kubernetes-model-2.0.0.jar`. This pr use `sc.setLogLevel ()` to implement silent mode. ## How was this patch tested? manual tests ``` build/sbt -Phive -Phive-thriftserver package export SPARK_PREPEND_CLASSES=true ./bin/spark-sql -S ``` Author: Yuming Wang Closes #20274 from wangyum/SPARK-20120-FOLLOW-UP. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 084f8200102ba..d9fd3ebd3c65d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.log4j.{Level, Logger} +import org.apache.log4j.Level import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf @@ -300,10 +300,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) - if (sessionState.getIsSilent) { - Logger.getRootLogger.setLevel(Level.WARN) - } - private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -315,6 +311,9 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // because the Hive unit tests do not go through the main() code path. if (!isRemoteMode) { SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } } else { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") From 8086acc2f676a04ce6255a621ffae871bd09ceea Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 22 May 2018 22:07:32 +0800 Subject: [PATCH 542/593] [SPARK-24244][SQL] Passing only required columns to the CSV parser ## What changes were proposed in this pull request? uniVocity parser allows to specify only required column names or indexes for [parsing](https://www.univocity.com/pages/parsers-tutorial) like: ``` // Here we select only the columns by their indexes. // The parser just skips the values in other columns parserSettings.selectIndexes(4, 0, 1); CsvParser parser = new CsvParser(parserSettings); ``` In this PR, I propose to extract indexes from required schema and pass them into the CSV parser. Benchmarks show the following improvements in parsing of 1000 columns: ``` Select 100 columns out of 1000: x1.76 Select 1 column out of 1000: x2 ``` **Note**: Comparing to current implementation, the changes can return different result for malformed rows in the `DROPMALFORMED` and `FAILFAST` modes if only subset of all columns is requested. To have previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## How was this patch tested? It was tested by new test which selects 3 columns out of 15, by existing tests and by new benchmarks. Author: Maxim Gekk Closes #21296 from MaxGekk/csv-column-pruning. --- docs/sql-programming-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 7 +++ .../datasources/csv/CSVOptions.scala | 3 ++ .../datasources/csv/UnivocityParser.scala | 26 ++++++----- .../datasources/csv/CSVBenchmarks.scala | 42 ++++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 43 ++++++++++++++++--- 6 files changed, 104 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f1ed316341b95..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a2fb3c64844b5..d0478d6ad250b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1295,6 +1295,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1066d156acd74..dd41aee0f2ebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -80,6 +81,8 @@ class CSVOptions( } } + private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index d442ba7e59c61..ec788df00aa92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,7 +74,49 @@ object CSVBenchmarks { } } + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X + Select 100 columns 28625 / 32884 0.0 28625.1 2.7X + Select one column 22498 / 22669 0.0 22497.8 3.4X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07e6c74b14d0d..5f9f799a6c466 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1368,4 +1370,31 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + checkAnswer( + idf, + List(Row(15, 10, 5), Row(-15, -10, -5)) + ) + } + } } From f9f055afa47412eec8228c843b34a90decb9be43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 01:50:22 +0800 Subject: [PATCH 543/593] [SPARK-24121][SQL] Add API for handling expression code generation ## What changes were proposed in this pull request? This patch tries to implement this [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It should allow us to manipulate how to generate codes for expressions. In details, this adds an new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and inputs for generating actual java code. For example, in following java code: ```java int ${variable} = 1; boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; ``` `variable`, `isNull` are two `VariableValue` and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and held by `CodeBlock` representing this code. For codegen, we provide a specified string interpolator `code`, so you can define a code like this: ```scala val codeBlock = code""" |int ${variable} = 1; |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; """.stripMargin // Generates actual java code. codeBlock.toString ``` Because those inputs are held separately in `CodeBlock` before generating code, we can safely manipulate them, e.g., replacing statements to aliased variables, etc.. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21193 from viirya/SPARK-24121. --- .../catalyst/expressions/BoundAttribute.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 +- .../sql/catalyst/expressions/Expression.scala | 26 ++-- .../MonotonicallyIncreasingID.scala | 3 +- .../sql/catalyst/expressions/ScalaUDF.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 3 +- .../expressions/SparkPartitionID.scala | 3 +- .../sql/catalyst/expressions/TimeWindow.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 13 +- .../expressions/codegen/CodeGenerator.scala | 25 +-- .../expressions/codegen/CodegenFallback.scala | 5 +- .../codegen/GenerateSafeProjection.scala | 7 +- .../codegen/GenerateUnsafeProjection.scala | 5 +- .../expressions/codegen/javaCode.scala | 145 +++++++++++++++++- .../expressions/collectionOperations.scala | 19 +-- .../expressions/complexTypeCreator.scala | 7 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 23 +-- .../expressions/decimalExpressions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../spark/sql/catalyst/expressions/hash.scala | 5 +- .../catalyst/expressions/inputFileBlock.scala | 14 +- .../expressions/mathExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 9 +- .../expressions/objects/objects.scala | 48 +++--- .../sql/catalyst/expressions/predicates.scala | 15 +- .../expressions/randomExpressions.scala | 5 +- .../expressions/regexpExpressions.scala | 9 +- .../expressions/stringExpressions.scala | 25 +-- .../ExpressionEvalHelperSuite.scala | 3 +- .../expressions/codegen/CodeBlockSuite.scala | 136 ++++++++++++++++ .../sql/execution/ColumnarBatchScan.scala | 9 +- .../spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 15 +- .../aggregate/HashAggregateExec.scala | 7 +- .../aggregate/HashMapGenerator.scala | 3 +- .../joins/BroadcastHashJoinExec.scala | 3 +- .../execution/joins/SortMergeJoinExec.scala | 5 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 41 files changed, 479 insertions(+), 172 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..df3ab05e02c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..699ea53b5df0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + + ev.copy(code = + code""" + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..9b9fa41a47d0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 9f0779642271d..f1da592a76845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..3e7ca88249737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..2ce9d072c71c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..84e38a8b2711e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index efd4e992c8eec..fe91e520169b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic { ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d382d9aace109..66315e5906253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -330,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -1056,7 +1057,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1084,7 +1085,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1141,7 +1142,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1160,9 +1161,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..8f2a5a0dce943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..250ce48d059e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -114,6 +115,147 @@ object JavaCode { } } +/** + * A trait representing a block of java code. + */ +trait Block extends JavaCode { + + // The expressions to be evaluated inside this block. + def exprValues: Set[ExprValue] + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def nonEmpty: Boolean = toString.nonEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + // Concatenates this block with other block. + def + (other: Block): Block +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: Block => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case _ => + buf.append(input) + } + buf.append(strings.next) + } + if (buf.nonEmpty) { + codeParts += buf.toString + } + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. + */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override lazy val exprValues: Set[ExprValue] = { + blockInputs.flatMap { + case b: Block => b.exprValues + case e: ExprValue => Set(e) + }.toSet + } + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + +case class Blocks(blocks: Seq[Block]) extends Block { + override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet + override lazy val code: String = blocks.map(_.toString).mkString("\n") + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(blocks :+ c) + case b: Blocks => Blocks(blocks ++ b.blocks) + case EmptyBlock => this + } +} + +object EmptyBlock extends Block with Serializable { + override val code: String = "" + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other +} + /** * A typed java fragment that must be a valid java expression. */ @@ -123,10 +265,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7da4c3cc6b9fa..c28eab71b84fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : @@ -1177,14 +1178,14 @@ case class ArrayJoin( } if (nullable) { ev.copy( - s""" + code""" |boolean ${ev.isNull} = true; |UTF8String ${ev.value} = null; |$code """.stripMargin) } else { ev.copy( - s""" + code""" |UTF8String ${ev.value} = null; |$code """.stripMargin, FalseLiteral) @@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression { expressions = inputs, funcName = "valueConcat", extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" $initCode $codes $javaType ${ev.value} = $concatenator.concat($args); @@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = false; |${leftGen.code} |${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..a9867aaeb0cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..77ac6c088022e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03422fecb3209..e8d85f72f7a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") val eval = child.genCode(ctx) - val code = s""" + val code = code""" |${eval.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; @@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..b7c52f1d7b40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..cec00b66f873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b7834696cafc3..5d98dac46cf17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..2eeed3bbb2d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f974fd81fc788..2bf4203d0fec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -269,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -385,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -492,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -532,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -935,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1709,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f8c6dc4e6adc9..f54103c4fbfba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 2653b28f6c3bd..926c2f00d430d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..7b68bb771faf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..9823b2fc5ad97 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..d2c6420eadb20 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block interpolates string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.exprValues + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin.trim + + assert(code.toString == expected) + + val exprValues = code.exprValues + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) + } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = Tuple2(1, 1) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("replace expr values in code block") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { + case _: SimpleExprValue => aliasedParam + case other => other + } + val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = code"${ctx.registerComment(str)}" + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..372dc3db36ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim + |${ev.code} + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + evaluateVars.append(ev.code.toString + "\n") + ev.code = EmptyBlock } } evaluateVars.toString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6a8ec4f722aea..8c7b2c187cccd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin @@ -773,8 +774,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..0da0e8610c392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } From bc6ea614ad4c6a323c78f209120287b256a458d3 Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Tue, 22 May 2018 13:01:07 -0700 Subject: [PATCH 544/593] [SPARK-24348][SQL] "element_at" error fix ## What changes were proposed in this pull request? ### Fixes a `scala.MatchError` in the `element_at` operation - [SPARK-24348](https://issues.apache.org/jira/browse/SPARK-24348) When calling `element_at` with a wrong first operand type an `AnalysisException` should be thrown instead of `scala.MatchError` *Example:* ```sql select element_at('foo', 1) ``` results in: ``` scala.MatchError: StringType (of class org.apache.spark.sql.types.StringType$) at org.apache.spark.sql.catalyst.expressions.ElementAt.inputTypes(collectionOperations.scala:1469) at org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes$class.checkInputDataTypes(ExpectsInputTypes.scala:44) at org.apache.spark.sql.catalyst.expressions.ElementAt.checkInputDataTypes(collectionOperations.scala:1478) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:168) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:168) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:256) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:252) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:288) ``` ## How was this patch tested? unit tests Author: Vayda, Oleksandr: IT (PRG) Closes #21395 from wajda/SPARK-24348-element_at-error-fix. --- .../sql/catalyst/expressions/collectionOperations.scala | 1 + .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c28eab71b84fd..03b3b21a16617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1470,6 +1470,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti left.dataType match { case _: ArrayType => IntegerType case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _ => AnyDataType // no match for a wrong 'left' expression type } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index df23e07e441a0..ec2a569f900d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -756,6 +756,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") + } + assert(e.message.contains( + "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } test("concat function - arrays") { From 79e06faa4ef6596c9e2d4be09c74b935064021bb Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 22 May 2018 13:43:45 -0700 Subject: [PATCH 545/593] [SPARK-19185][DSTREAMS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? `CachedKafkaConsumer` in the project streaming-kafka-0-10 is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one thread trying to read the same Kafka TopicPartition at the same time. This assumption is not true all the time and this can inadvertently lead to ConcurrentModificationException. Here is a better way to design this. The consumer pool should be smart enough to avoid concurrent use of a cached consumer. If there is another request for the same TopicPartition as a currently in-use consumer, the pool should automatically return a fresh consumer. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called `KafkaDataConsumer` is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply call `val consumer = KafkaDataConsumer.acquire` and then `consumer.release`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is request for a consumer which is a task reattempt, then already existing cached consumer will be invalidated and a new consumer is generated. This could fix potential issues if the source of the reattempt is a malfunctioning consumer. - In addition, I renamed the `CachedKafkaConsumer` class to `KafkaDataConsumer` because is a misnomer given that what it returns may or may not be cached. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same TopicPartition from the consumer pool. Author: Gabor Somogyi Closes #20997 from gaborgsomogyi/SPARK-19185. --- .../sql/kafka010/KafkaDataConsumer.scala | 2 +- .../kafka010/CachedKafkaConsumer.scala | 226 ----------- .../kafka010/KafkaDataConsumer.scala | 359 ++++++++++++++++++ .../spark/streaming/kafka010/KafkaRDD.scala | 20 +- .../kafka010/KafkaDataConsumerSuite.scala | 131 +++++++ 5 files changed, 496 insertions(+), 242 deletions(-) delete mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala create mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 48508d057a540..941f0ab177e48 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging { // likely running on a beefy machine that can handle a large number of simultaneously // active consumers. - if (entry.getValue.inUse == false && this.size > capacity) { + if (!entry.getValue.inUse && this.size > capacity) { logWarning( s"KafkaConsumer cache hitting max capacity of $capacity, " + s"removing consumer for ${entry.getKey}") diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala deleted file mode 100644 index aeb8c1dc342b3..0000000000000 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kafka010 - -import java.{ util => ju } - -import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } -import org.apache.kafka.common.{ KafkaException, TopicPartition } - -import org.apache.spark.internal.Logging - -/** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. - */ -private[kafka010] -class CachedKafkaConsumer[K, V] private( - val groupId: String, - val topic: String, - val partition: Int, - val kafkaParams: ju.Map[String, Object]) extends Logging { - - require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), - "groupId used for cache key must match the groupId in kafkaParams") - - val topicPartition = new TopicPartition(topic, partition) - - protected val consumer = { - val c = new KafkaConsumer[K, V](kafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - // TODO if the buffer was kept around as a random-access structure, - // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() - protected var nextOffset = -2L - - def close(): Unit = consumer.close() - - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. - */ - def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { - logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") - if (offset != nextOffset) { - logInfo(s"Initial fetch for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - - if (!buffer.hasNext()) { poll(timeout) } - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - var record = buffer.next() - - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - record = buffer.next() - require(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + - s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + - "spark.streaming.kafka.allowNonConsecutiveOffsets" - ) - } - - nextOffset = offset + 1 - record - } - - /** - * Start a batch on a compacted topic - */ - def compactedStart(offset: Long, timeout: Long): Unit = { - logDebug(s"compacted start $groupId $topic $partition starting $offset") - // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics - if (offset != nextOffset) { - logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - } - - /** - * Get the next record in the batch from a compacted topic. - * Assumes compactedStart has been called first, and ignores gaps. - */ - def compactedNext(timeout: Long): ConsumerRecord[K, V] = { - if (!buffer.hasNext()) { - poll(timeout) - } - require(buffer.hasNext(), - s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") - val record = buffer.next() - nextOffset = record.offset + 1 - record - } - - /** - * Rewind to previous record in the batch from a compacted topic. - * @throws NoSuchElementException if no previous element - */ - def compactedPrevious(): ConsumerRecord[K, V] = { - buffer.previous() - } - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - private def poll(timeout: Long): Unit = { - val p = consumer.poll(timeout) - val r = p.records(topicPartition) - logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.listIterator - } - -} - -private[kafka010] -object CachedKafkaConsumer extends Logging { - - private case class CacheKey(groupId: String, topic: String, partition: Int) - - // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap - private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null - - /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */ - def init( - initialCapacity: Int, - maxCapacity: Int, - loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { - if (null == cache) { - logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") - cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( - initialCapacity, loadFactor, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { - if (this.size > maxCapacity) { - try { - entry.getValue.consumer.close() - } catch { - case x: KafkaException => - logError("Error closing oldest Kafka consumer", x) - } - true - } else { - false - } - } - } - } - } - - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def get[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - CachedKafkaConsumer.synchronized { - val k = CacheKey(groupId, topic, partition) - val v = cache.get(k) - if (null == v) { - logInfo(s"Cache miss for $k") - logDebug(cache.keySet.toString) - val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - cache.put(k, c) - c - } else { - // any given topicpartition should have a consistent key and value type - v.asInstanceOf[CachedKafkaConsumer[K, V]] - } - } - - /** - * Get a fresh new instance, unassociated with the global cache. - * Caller is responsible for closing - */ - def getUncached[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - - /** remove consumer for given groupId, topic, and partition, if it exists */ - def remove(groupId: String, topic: String, partition: Int): Unit = { - val k = CacheKey(groupId, topic, partition) - logInfo(s"Removing $k from cache") - val v = CachedKafkaConsumer.synchronized { - cache.remove(k) - } - if (null != v) { - v.close() - logInfo(s"Removed $k from cache") - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala new file mode 100644 index 0000000000000..68c5fe9ab066a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging + +private[kafka010] sealed trait KafkaDataConsumer[K, V] { + /** + * Get the record for the given offset if available. + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.get(offset, pollTimeoutMs) + } + + /** + * Start a batch on a compacted topic + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + internalConsumer.compactedStart(offset, pollTimeoutMs) + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + * + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.compactedNext(pollTimeoutMs) + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + internalConsumer.compactedPrevious() + } + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + def internalConsumer: InternalKafkaConsumer[K, V] +} + + +/** + * A wrapper around Kafka's KafkaConsumer. + * This is not for direct use outside this file. + */ +private[kafka010] class InternalKafkaConsumer[K, V]( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + .asInstanceOf[String] + + private val consumer = createConsumer + + /** indicates whether this consumer is in use or not */ + var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + var markedForClose = false + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() + @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET + + override def toString: String = { + "InternalKafkaConsumer(" + + s"hash=${Integer.toHexString(hashCode)}, " + + s"groupId=$groupId, " + + s"topicPartition=$topicPartition)" + } + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[K, V] = { + val c = new KafkaConsumer[K, V](kafkaParams) + val topics = ju.Arrays.asList(topicPartition) + c.assign(topics) + c + } + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + record = buffer.next() + require(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) + } + + nextOffset = offset + 1 + record + } + + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + logDebug(s"compacted start $groupId $topicPartition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(pollTimeoutMs) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topicPartition " + + s"after polling for $pollTimeoutMs") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.listIterator + } + +} + +private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) + +private[kafka010] object KafkaDataConsumer extends Logging { + + private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + assert(internalConsumer.inUse) + override def release(): Unit = KafkaDataConsumer.release(internalConsumer) + } + + private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + override def release(): Unit = internalConsumer.close() + } + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null + + /** + * Must be called before acquire, once per JVM, to configure the cache. + * Further calls are ignored. + */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > maxCapacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $maxCapacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by anyone + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. + */ + def acquire[K, V]( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + context: TaskContext, + useCache: Boolean): KafkaDataConsumer[K, V] = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = cache.get(key) + + lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams) + + if (context != null && context.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumers if any and + // start with a new one. If prior attempt failures were cache related then this way old + // problematic consumers can be removed. + logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer") + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + // Remove the consumer from cache only if it's closed. + // Marked for close consumers will be removed in release function. + cache.remove(key) + } + } + + logDebug("Reattempt detected, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (!useCache) { + // If consumer reuse turned off, then do not use it, return a new consumer + logDebug("Cache usage turned off, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + logDebug("No cached consumer, new cached consumer will be allocated " + + s"$newInternalConsumer") + cache.put(key, newInternalConsumer) + CachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + logDebug("Used cached consumer found, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else { + // If consumer is already cached and is currently not in use, then return that consumer + logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer") + existingInternalConsumer.inUse = true + // Any given TopicPartition should have a consistent key and value type + CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]]) + } + } + + private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized { + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition) + val cachedInternalConsumer = cache.get(key) + if (internalConsumer.eq(cachedInternalConsumer)) { + // The released consumer is the same object as the cached one. + if (internalConsumer.markedForClose) { + internalConsumer.close() + cache.remove(key) + } else { + internalConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + internalConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache " + + s"$internalConsumer") + } + } +} + +private[kafka010] object InternalKafkaConsumer { + private val UNKNOWN_OFFSET = -2L +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 07239eda64d2e..81abc9860bfc3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } -import scala.collection.mutable.ArrayBuffer - import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } import org.apache.kafka.common.TopicPartition @@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - context.addTaskCompletionListener(_ => closeIfNeeded()) - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + val consumer = { + KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache) } var requestOffset = part.fromOffset def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close() + if (consumer != null) { + consumer.release() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..d934c64962adb --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll { + private var testUtils: KafkaTestUtils = _ + private val topic = "topic" + Random.nextInt() + private val topicPartition = new TopicPartition(topic, 0) + private val groupId = "groupId" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + KafkaDataConsumer.init(16, 64, 0.75f) + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + private def getKafkaParams() = Map[String, Object]( + GROUP_ID_CONFIG -> groupId, + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") { + KafkaDataConsumer.cache.clear() + + val kafkaParams = getKafkaParams() + + val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer1.release() + + val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer2.release() + + assert(KafkaDataConsumer.cache.size() == 1) + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = KafkaDataConsumer.cache.get(key) + assert(existingInternalConsumer.eq(consumer1.internalConsumer)) + assert(existingInternalConsumer.eq(consumer2.internalConsumer)) + } + + test("concurrent use of KafkaDataConsumer") { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic) + testUtils.sendMessages(topic, data.toArray) + + val kafkaParams = getKafkaParams() + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, taskContext, useCache) + try { + val rcvd = (0 until data.length).map { offset => + val bytes = consumer.get(offset, 10000).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadPool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadPool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadPool.shutdown() + } + } +} From 00c13cfad78607fde0787c9d494f0df8ab7051ba Mon Sep 17 00:00:00 2001 From: Seth Fitzsimmons Date: Wed, 23 May 2018 09:14:03 +0800 Subject: [PATCH 546/593] Correct reference to Offset class This is a documentation-only correction; `org.apache.spark.sql.sources.v2.reader.Offset` is actually `org.apache.spark.sql.sources.v2.reader.streaming.Offset`. Author: Seth Fitzsimmons Closes #21387 from mojodna/patch-1. --- .../org/apache/spark/sql/execution/streaming/Offset.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java index 80aa5505db991..43ad4b3384ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -19,8 +19,8 @@ /** * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported - * in the long term. + * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be + * supported in the long term. * * This class will be removed in a future release. */ From a40ffc656d62372da85e0fa932b67207839e7fde Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 22:40:52 +0800 Subject: [PATCH 547/593] [SPARK-23711][SQL] Add fallback generator for UnsafeProjection ## What changes were proposed in this pull request? Add fallback logic for `UnsafeProjection`. In production we can try to create unsafe projection using codegen implementation. Once any compile error happens, it fallbacks to interpreted implementation. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21106 from viirya/SPARK-23711. --- ...CodeGeneratorWithInterpretedFallback.scala | 73 +++++++++++++++++++ .../InterpretedUnsafeProjection.scala | 5 +- .../sql/catalyst/expressions/Projection.scala | 60 +++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 12 +++ ...eneratorWithInterpretedFallbackSuite.scala | 43 +++++++++++ .../expressions/ExpressionEvalHelper.scala | 52 ++++++------- .../expressions/ObjectExpressionsSuite.scala | 18 +---- .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++----- .../UnsafeKVExternalSorterSuite.scala | 3 +- 9 files changed, 230 insertions(+), 88 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala new file mode 100644 index 0000000000000..fb25e781e72e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.InternalCompilerException + +import org.apache.spark.TaskContext +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Catches compile error during code generation. + */ +object CodegenError { + def unapply(throwable: Throwable): Option[Exception] = throwable match { + case e: InternalCompilerException => Some(e) + case e: CompileException => Some(e) + case _ => None + } +} + +/** + * Defines values for `SQLConf` config of fallback mode. Use for test only. + */ +object CodegenObjectFactoryMode extends Enumeration { + val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value +} + +/** + * A codegen object generator which creates objects with codegen path first. Once any compile + * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. + */ +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { + + def createObject(in: IN): OUT = { + // We are allowed to choose codegen-only or no-codegen modes if under tests. + val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) + val fallbackMode = CodegenObjectFactoryMode.withName(config) + + fallbackMode match { + case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting => + createCodeGeneratedObject(in) + case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting => + createInterpretedObject(in) + case _ => + try { + createCodeGeneratedObject(in) + } catch { + case CodegenError(_) => createInterpretedObject(in) + } + } + } + + protected def createCodeGeneratedObject(in: IN): OUT + protected def createInterpretedObject(in: IN): OUT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6d69d69b1c802..55a5bd380859e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** * Helper functions for creating an [[InterpretedUnsafeProjection]]. */ -object InterpretedUnsafeProjection extends UnsafeProjectionCreator { - +object InterpretedUnsafeProjection { /** * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + def createProjection(exprs: Seq[Expression]): UnsafeProjection = { // We need to make sure that we do not reuse stateful expressions. val cleanedExpressions = exprs.map(_.transform { case s: Stateful => s.freshCopy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 3cd73682188bc..6493f09100577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -trait UnsafeProjectionCreator { +/** + * The factory object for `UnsafeProjection`. + */ +object UnsafeProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + + protected def toBoundExprs( + exprs: Seq[Expression], + inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) + } + + protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { + exprs.map(_ transform { + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + } + /** * Returns an UnsafeProjection for given StructType. * @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator { * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - createProjection(unsafeExprs) + createObject(toUnsafeExprs(exprs)) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator { * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(exprs.map(BindReferences.bindReference(_, inputSchema))) - } - - /** - * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. - */ - protected def createProjection(exprs: Seq[Expression]): UnsafeProjection -} - -object UnsafeProjection extends UnsafeProjectionCreator { - - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + create(toBoundExprs(exprs, inputSchema)) } /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. + * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, + * when fallbacking to interpreted execution, it is not supported. */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) + try { + GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) + } catch { + case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d0478d6ad250b..fb98df587129b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.Utils @@ -703,6 +704,17 @@ object SQLConf { .intConf .createWithDefault(100) + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then fallbacking to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " + + "this config works only for tests.") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala new file mode 100644 index 0000000000000..531ca9a87370a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType} + +class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + + test("UnsafeProjection with codegen factory mode") { + val input = Seq(LongType, IntegerType) + .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = UnsafeProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) + } + + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = UnsafeProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c2a44e0d33b18..14bfa212b5496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -205,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) - } - - protected def checkEvaluationWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow, - factory: UnsafeProjectionCreator): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") - } - } else { - val lit = InternalRow(expected, expected) - val expectedRow = - factory.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected, expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } } } } protected def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow, - factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + inputRow: InternalRow = EmptyRow): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 77ca640f2e0bd..20d568c44258f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -81,10 +81,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) checkEvaluationWithUnsafeProjection( - structEncoder.serializer.head, - structExpected, - structInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + structEncoder.serializer.head, structExpected, structInputRow) // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] @@ -92,10 +89,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) checkEvaluationWithUnsafeProjection( - arrayEncoder.serializer.head, - arrayExpected, - arrayInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) checkEvaluationWithUnsafeProjection( - mapEncoder.serializer.head, - mapExpected, - mapInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + mapEncoder.serializer.head, mapExpected, mapInputRow) } test("SPARK-23582: StaticInvoke should support interpreted execution") { @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithUnsafeProjection( expr, expected, - inputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + inputRow) } checkEvaluationWithOptimization(expr, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c07da122cd7b8..5a646d9a850ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,25 +24,30 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testWithFactory( - name: String)( - f: UnsafeProjectionCreator => Unit): Unit = { - test(name) { - f(UnsafeProjection) - f(InterpretedUnsafeProjection) + private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } } } - testWithFactory("basic conversion with only primitive types") { factory => + testBothCodegenAndInterpreted("basic conversion with only primitive types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) @@ -79,7 +84,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - testWithFactory("basic conversion with primitive, string and binary types") { factory => + testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) @@ -98,7 +104,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => + testBothCodegenAndInterpreted( + "basic conversion with primitive, string, date and timestamp types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) @@ -127,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - testWithFactory("null handling") { factory => + testBothCodegenAndInterpreted("null handling") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -248,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testWithFactory("NaN canonicalization") { factory => + testBothCodegenAndInterpreted("NaN canonicalization") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -263,7 +273,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - testWithFactory("basic conversion with struct type") { factory => + testBothCodegenAndInterpreted("basic conversion with struct type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) @@ -325,7 +336,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - testWithFactory("basic conversion with array type") { factory => + testBothCodegenAndInterpreted("basic conversion with array type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) @@ -355,7 +367,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - testWithFactory("basic conversion with map type") { factory => + testBothCodegenAndInterpreted("basic conversion with map type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) @@ -401,7 +414,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - testWithFactory("basic conversion with struct and array") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and array") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) @@ -440,7 +454,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with struct and map") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) @@ -486,7 +501,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with array and map") { factory => + testBothCodegenAndInterpreted("basic conversion with array and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bf588d3bb7841..c882a9dd2148c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -231,7 +231,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` // which has duplicated keys and the number of entries exceeds its capacity. try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null) + TaskContext.setTaskContext(context) new UnsafeKVExternalSorter( schema, schema, From df125062c8dac9fee3328d67dd438a456b7a3b74 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 23 May 2018 11:00:23 -0700 Subject: [PATCH 548/593] [SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? Change `PrefixSpan` into a class with param setter/getters. This address issues mentioned here: https://github.com/apache/spark/pull/20973#discussion_r186931806 ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21393 from WeichenXu123/fix_prefix_span. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 127 ++++++++++++++---- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 28 ++-- 2 files changed, 119 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 02168fee16dbf..41716c621ca98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -29,13 +31,97 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth * (see here). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. * * @see Sequential Pattern Mining * (Wikipedia) */ @Since("2.4.0") @Experimental -object PrefixSpan { +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset)." + + "times will be output.", ParamValidators.gtEq(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** @group setParam */ + @Since("2.4.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * Param for the maximal pattern length (default: `10`). + * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "The maximal length of the sequential pattern.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Int = $(maxPatternLength) + + /** @group setParam */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + /** + * Param for the name of the sequence column in dataset (default "sequence"), rows with + * nulls in this column are ignored. + * @group param + */ + @Since("2.4.0") + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.") + + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") /** * :: Experimental :: @@ -43,54 +129,39 @@ object PrefixSpan { * * @param dataset A dataset or a dataframe containing a sequence column which is * {{{Seq[Seq[_]]}}} type - * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column - * are ignored - * @param minSupport the minimal support level of the sequential pattern, any pattern that - * appears more than (minSupport * size-of-the-dataset) times will be output - * (recommended value: `0.1`). - * @param maxPatternLength the maximal length of the sequential pattern - * (recommended value: `10`). - * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the - * internal storage format) allowed in a projected database before - * local processing. If a projected database exceeds this size, another - * iteration of distributed prefix growth is run - * (recommended value: `32000000`). * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: * - `sequence: Seq[Seq[T]]` (T is the item type) * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns( - dataset: Dataset[_], - sequenceCol: String, - minSupport: Double, - maxPatternLength: Int, - maxLocalProjDBSize: Long): DataFrame = { - - val inputType = dataset.schema(sequenceCol).dataType + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], s"The input column must be ArrayType and the array element type must also be ArrayType, " + s"but got $inputType.") - - val data = dataset.select(sequenceCol) - val sequences = data.where(col(sequenceCol).isNotNull).rdd + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) val mllibPrefixSpan = new mllibPrefixSpan() - .setMinSupport(minSupport) - .setMaxPatternLength(maxPatternLength) - .setMaxLocalProjDBSize(maxLocalProjDBSize) + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) val schema = StructType(Seq( - StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) freqSequences } + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala index 9e538696cbcf7..2252151af306b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan projections with multiple partial starts") { val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", - minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset) .as[(Seq[Seq[Int]], Long)].collect() val expected = Array( (Seq(Seq(1)), 1L), @@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan Integer type, variable-size itemsets") { val df = smallTestData.toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan input row with nulls") { val df = (smallTestData :+ null).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest { val df = smallTestData .map(seq => seq.map(itemSet => itemSet.map(intToString))) .toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[String]], Long)].collect() val expected = smallTestDataExpectedResult.map { case (seq, freq) => From 5a5a868dc410ad8c97851d7f3f0ea1c9fc1db90c Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 23 May 2018 11:51:13 -0700 Subject: [PATCH 549/593] Revert "[SPARK-24244][SQL] Passing only required columns to the CSV parser" This reverts commit 8086acc2f676a04ce6255a621ffae871bd09ceea. --- docs/sql-programming-guide.md | 1 - .../apache/spark/sql/internal/SQLConf.scala | 7 --- .../datasources/csv/CSVOptions.scala | 3 -- .../datasources/csv/UnivocityParser.scala | 26 +++++------ .../datasources/csv/CSVBenchmarks.scala | 42 ------------------ .../execution/datasources/csv/CSVSuite.scala | 43 +++---------------- 6 files changed, 18 insertions(+), 104 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fc26562ff33da..f1ed316341b95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,7 +1825,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fb98df587129b..15ba10f604510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1307,13 +1307,6 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } - - val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") - .internal() - .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + - "Other column values can be ignored during parsing even if they are malformed.") - .booleanConf - .createWithDefault(true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index dd41aee0f2ebc..1066d156acd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,7 +25,6 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -81,8 +80,6 @@ class CSVOptions( } } - private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) - val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f39..99557a1ceb0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - dataSchema: StructType, + schema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(dataSchema.toSet), + require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,17 +45,9 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = { - val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) - parserSetting.selectIndexes(tokenIndexArr: _*) - } - new CsvParser(parserSetting) - } - private val schema = if (options.columnPruning) requiredSchema else dataSchema + private val tokenizer = new CsvParser(options.asParserSettings) - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -81,8 +73,11 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = { + private val valueConverters: Array[ValueConverter] = schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + + private val tokenIndexArr: Array[Int] = { + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -215,8 +210,9 @@ class UnivocityParser( } else { try { var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index ec788df00aa92..d442ba7e59c61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,49 +74,7 @@ object CSVBenchmarks { } } - def multiColumnsBenchmark(rowsNum: Int): Unit = { - val colsNum = 1000 - val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) - - withTempPath { path => - val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) - val schema = StructType(fields) - val values = (0 until colsNum).map(i => i.toString).mkString(",") - val columnNames = schema.fieldNames - - spark.range(rowsNum) - .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) - .write.option("header", true) - .csv(path.getAbsolutePath) - - val ds = spark.read.schema(schema).csv(path.getAbsolutePath) - - benchmark.addCase(s"Select $colsNum columns", 3) { _ => - ds.select("*").filter((row: Row) => true).count() - } - val cols100 = columnNames.take(100).map(Column(_)) - benchmark.addCase(s"Select 100 columns", 3) { _ => - ds.select(cols100: _*).filter((row: Row) => true).count() - } - benchmark.addCase(s"Select one column", 3) { _ => - ds.select($"col1").filter((row: Row) => true).count() - } - - /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz - - Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X - Select 100 columns 28625 / 32884 0.0 28625.1 2.7X - Select one column 22498 / 22669 0.0 22497.8 3.4X - */ - benchmark.run() - } - } - def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) - multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5f9f799a6c466..07e6c74b14d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,16 +260,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) - } + assert(cars.select("year").collect().size === 2) } } @@ -1370,31 +1368,4 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } - - test("SPARK-24244: Select a subset of all columns") { - withTempPath { path => - import collection.JavaConverters._ - val schema = new StructType() - .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) - .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) - .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) - .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) - .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) - - val odf = spark.createDataFrame(List( - Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) - ).asJava, schema) - odf.write.csv(path.getCanonicalPath) - val idf = spark.read - .schema(schema) - .csv(path.getCanonicalPath) - .select('f15, 'f10, 'f5) - - checkAnswer( - idf, - List(Row(15, 10, 5), Row(-15, -10, -5)) - ) - } - } } From 84557bc9f87885577f22d7e7f2220c0b988014cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 23 May 2018 13:02:32 -0700 Subject: [PATCH 550/593] [SPARK-24206][SQL] Improve DataSource read benchmark code ## What changes were proposed in this pull request? This pr added benchmark code `DataSourceReadBenchmark` for `orc`, `paruqet`, `csv`, and `json` based on the existing `ParquetReadBenchmark` and `OrcReadBenchmark`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21266 from maropu/DataSourceReadBenchmark. --- .../benchmark/DataSourceReadBenchmark.scala | 828 ++++++++++++++++++ .../parquet/ParquetReadBenchmark.scala | 339 ------- 2 files changed, 828 insertions(+), 339 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala new file mode 100644 index 0000000000000..fc6d8abc03c09 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -0,0 +1,828 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnVector +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure data source read performance. + * To run this: + * spark-submit --class + */ +object DataSourceReadBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceReadBenchmark") + .setIfMissing("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + spark.conf.set(SQLConf.ORC_COPY_BATCH_TO_SPARK.key, "false") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") + saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") + saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") + saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") + } + + private def saveAsCsvTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").option("header", true).csv(dir) + spark.read.option("header", true).csv(dir).createOrReplaceTempView("csvTable") + } + + private def saveAsJsonTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").json(dir) + spark.read.json(dir).createOrReplaceTempView("jsonTable") + } + + private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + // Benchmarks driving reader component directly. + val parquetReaderBenchmark = new Benchmark( + s"Parquet Reader Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + sqlBenchmark.addCase("SQL Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + sqlBenchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 15231 / 15267 1.0 968.3 1.0X + SQL Json 8476 / 8498 1.9 538.9 1.8X + SQL Parquet Vectorized 121 / 127 130.0 7.7 125.9X + SQL Parquet MR 1515 / 1543 10.4 96.3 10.1X + SQL ORC Vectorized 164 / 171 95.9 10.4 92.9X + SQL ORC Vectorized with copy 228 / 234 69.0 14.5 66.8X + SQL ORC MR 1297 / 1309 12.1 82.5 11.7X + + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 16344 / 16374 1.0 1039.1 1.0X + SQL Json 8634 / 8648 1.8 548.9 1.9X + SQL Parquet Vectorized 172 / 177 91.5 10.9 95.1X + SQL Parquet MR 1744 / 1746 9.0 110.9 9.4X + SQL ORC Vectorized 189 / 194 83.1 12.0 86.4X + SQL ORC Vectorized with copy 244 / 250 64.5 15.5 67.0X + SQL ORC MR 1341 / 1386 11.7 85.3 12.2X + + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17874 / 17875 0.9 1136.4 1.0X + SQL Json 9190 / 9204 1.7 584.3 1.9X + SQL Parquet Vectorized 141 / 160 111.2 9.0 126.4X + SQL Parquet MR 1930 / 2049 8.2 122.7 9.3X + SQL ORC Vectorized 259 / 264 60.7 16.5 69.0X + SQL ORC Vectorized with copy 265 / 272 59.4 16.8 67.5X + SQL ORC MR 1528 / 1569 10.3 97.2 11.7X + + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 22812 / 22839 0.7 1450.4 1.0X + SQL Json 12026 / 12054 1.3 764.6 1.9X + SQL Parquet Vectorized 222 / 227 70.8 14.1 102.6X + SQL Parquet MR 2199 / 2204 7.2 139.8 10.4X + SQL ORC Vectorized 331 / 335 47.6 21.0 69.0X + SQL ORC Vectorized with copy 338 / 343 46.6 21.5 67.6X + SQL ORC MR 1618 / 1622 9.7 102.9 14.1X + + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 18703 / 18740 0.8 1189.1 1.0X + SQL Json 11779 / 11869 1.3 748.9 1.6X + SQL Parquet Vectorized 143 / 145 110.1 9.1 130.9X + SQL Parquet MR 1954 / 1963 8.0 124.2 9.6X + SQL ORC Vectorized 347 / 355 45.3 22.1 53.8X + SQL ORC Vectorized with copy 356 / 359 44.1 22.7 52.5X + SQL ORC MR 1570 / 1598 10.0 99.8 11.9X + + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 23832 / 23838 0.7 1515.2 1.0X + SQL Json 16204 / 16226 1.0 1030.2 1.5X + SQL Parquet Vectorized 242 / 306 65.1 15.4 98.6X + SQL Parquet MR 2462 / 2482 6.4 156.5 9.7X + SQL ORC Vectorized 419 / 451 37.6 26.6 56.9X + SQL ORC Vectorized with copy 426 / 447 36.9 27.1 55.9X + SQL ORC MR 1885 / 1931 8.3 119.8 12.6X + */ + sqlBenchmark.run() + + // Driving the parquet reader in batch mode directly. + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) + case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) + case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) + case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) + case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) + case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) aggregateValue(col, i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (InternalRow) => Unit = dataType match { + case ByteType => (col: InternalRow) => longSum += col.getByte(0) + case ShortType => (col: InternalRow) => longSum += col.getShort(0) + case IntegerType => (col: InternalRow) => longSum += col.getInt(0) + case LongType => (col: InternalRow) => longSum += col.getLong(0) + case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) + case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) aggregateValue(record) + } + } + } finally { + reader.close() + } + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 187 / 201 84.2 11.9 1.0X + ParquetReader Vectorized -> Row 101 / 103 156.4 6.4 1.9X + + + Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 272 / 288 57.8 17.3 1.0X + ParquetReader Vectorized -> Row 213 / 219 73.7 13.6 1.3X + + + Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 252 / 288 62.5 16.0 1.0X + ParquetReader Vectorized -> Row 232 / 246 67.7 14.8 1.1X + + + Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 415 / 454 37.9 26.4 1.0X + ParquetReader Vectorized -> Row 407 / 432 38.6 25.9 1.0X + + + Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 251 / 302 62.7 16.0 1.0X + ParquetReader Vectorized -> Row 220 / 234 71.5 14.0 1.1X + + + Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 432 / 436 36.4 27.5 1.0X + ParquetReader Vectorized -> Row 414 / 422 38.0 26.4 1.0X + */ + parquetReaderBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(c1), sum(length(c2)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(c1), sum(length(c2)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 19172 / 19173 0.5 1828.4 1.0X + SQL Json 12799 / 12873 0.8 1220.6 1.5X + SQL Parquet Vectorized 2558 / 2564 4.1 244.0 7.5X + SQL Parquet MR 4514 / 4583 2.3 430.4 4.2X + SQL ORC Vectorized 2561 / 2697 4.1 244.3 7.5X + SQL ORC Vectorized with copy 3076 / 3110 3.4 293.4 6.2X + SQL ORC MR 4197 / 4283 2.5 400.2 4.6X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("select cast((value % 200) + 10000 as STRING) as c1 from t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c1)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c1)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("select sum(length(c1)) from orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 10889 / 10924 1.0 1038.5 1.0X + SQL Json 7903 / 7931 1.3 753.7 1.4X + SQL Parquet Vectorized 777 / 799 13.5 74.1 14.0X + SQL Parquet MR 1682 / 1708 6.2 160.4 6.5X + SQL ORC Vectorized 532 / 534 19.7 50.7 20.5X + SQL ORC Vectorized with copy 742 / 743 14.1 70.7 14.7X + SQL ORC MR 1996 / 2002 5.3 190.4 5.5X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column - CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + benchmark.addCase("Data column - Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + benchmark.addCase("Data column - Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + benchmark.addCase("Data column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + benchmark.addCase("Data column - ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Data column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Data column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - CSV") { _ => + spark.sql("select sum(p) from csvTable").collect() + } + + benchmark.addCase("Partition column - Json") { _ => + spark.sql("select sum(p) from jsonTable").collect() + } + + benchmark.addCase("Partition column - Parquet Vectorized") { _ => + spark.sql("select sum(p) from parquetTable").collect() + } + + benchmark.addCase("Partition column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p) from parquetTable").collect() + } + } + + benchmark.addCase("Partition column - ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + + benchmark.addCase("Partition column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - CSV") { _ => + spark.sql("select sum(p), sum(id) from csvTable").collect() + } + + benchmark.addCase("Both columns - Json") { _ => + spark.sql("select sum(p), sum(id) from jsonTable").collect() + } + + benchmark.addCase("Both columns - Parquet Vectorized") { _ => + spark.sql("select sum(p), sum(id) from parquetTable").collect() + } + + benchmark.addCase("Both columns - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p), sum(id) from parquetTable").collect + } + } + + benchmark.addCase("Both columns - ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Both column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Data column - CSV 25428 / 25454 0.6 1616.7 1.0X + Data column - Json 12689 / 12774 1.2 806.7 2.0X + Data column - Parquet Vectorized 222 / 231 70.7 14.1 114.3X + Data column - Parquet MR 3355 / 3397 4.7 213.3 7.6X + Data column - ORC Vectorized 332 / 338 47.4 21.1 76.6X + Data column - ORC Vectorized with copy 338 / 341 46.5 21.5 75.2X + Data column - ORC MR 2329 / 2356 6.8 148.0 10.9X + Partition column - CSV 17465 / 17502 0.9 1110.4 1.5X + Partition column - Json 10865 / 10876 1.4 690.8 2.3X + Partition column - Parquet Vectorized 48 / 52 325.4 3.1 526.1X + Partition column - Parquet MR 1695 / 1696 9.3 107.8 15.0X + Partition column - ORC Vectorized 49 / 54 319.9 3.1 517.2X + Partition column - ORC Vectorized with copy 49 / 52 324.1 3.1 524.0X + Partition column - ORC MR 1548 / 1549 10.2 98.4 16.4X + Both columns - CSV 25568 / 25595 0.6 1625.6 1.0X + Both columns - Json 13658 / 13673 1.2 868.4 1.9X + Both columns - Parquet Vectorized 270 / 296 58.3 17.1 94.3X + Both columns - Parquet MR 3501 / 3521 4.5 222.6 7.3X + Both columns - ORC Vectorized 377 / 380 41.7 24.0 67.4X + Both column - ORC Vectorized with copy 447 / 448 35.2 28.4 56.9X + Both columns - ORC MR 2440 / 2446 6.4 155.2 10.4X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + val benchmark = new Benchmark("String with Nulls Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c2)) from csvTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c2)) from jsonTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + benchmark.addCase("ParquetReader Vectorized") { num => + var sum = 0 + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } + } + } finally { + reader.close() + } + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 13518 / 13529 0.8 1289.2 1.0X + SQL Json 10895 / 10926 1.0 1039.0 1.2X + SQL Parquet Vectorized 1539 / 1581 6.8 146.8 8.8X + SQL Parquet MR 3746 / 3811 2.8 357.3 3.6X + ParquetReader Vectorized 1070 / 1112 9.8 102.0 12.6X + SQL ORC Vectorized 1389 / 1408 7.6 132.4 9.7X + SQL ORC Vectorized with copy 1736 / 1750 6.0 165.6 7.8X + SQL ORC MR 3799 / 3892 2.8 362.3 3.6X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 10854 / 10892 1.0 1035.2 1.0X + SQL Json 8129 / 8138 1.3 775.3 1.3X + SQL Parquet Vectorized 1053 / 1104 10.0 100.4 10.3X + SQL Parquet MR 2840 / 2854 3.7 270.8 3.8X + ParquetReader Vectorized 978 / 1008 10.7 93.2 11.1X + SQL ORC Vectorized 1312 / 1387 8.0 125.1 8.3X + SQL ORC Vectorized with copy 1764 / 1772 5.9 168.2 6.2X + SQL ORC MR 3435 / 3445 3.1 327.6 3.2X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 8043 / 8048 1.3 767.1 1.0X + SQL Json 4911 / 4923 2.1 468.4 1.6X + SQL Parquet Vectorized 206 / 209 51.0 19.6 39.1X + SQL Parquet MR 1528 / 1537 6.9 145.8 5.3X + ParquetReader Vectorized 216 / 219 48.6 20.6 37.2X + SQL ORC Vectorized 462 / 466 22.7 44.1 17.4X + SQL ORC Vectorized with copy 568 / 572 18.5 54.2 14.2X + SQL ORC MR 1647 / 1649 6.4 157.1 4.9X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql(s"SELECT sum(c$middle) FROM csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql(s"SELECT sum(c$middle) FROM jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 3663 / 3665 0.3 3493.2 1.0X + SQL Json 3122 / 3160 0.3 2977.5 1.2X + SQL Parquet Vectorized 40 / 42 26.2 38.2 91.5X + SQL Parquet MR 189 / 192 5.5 180.2 19.4X + SQL ORC Vectorized 48 / 51 21.6 46.2 75.6X + SQL ORC Vectorized with copy 49 / 52 21.4 46.7 74.9X + SQL ORC MR 280 / 289 3.7 267.1 13.1X + + + Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 11420 / 11505 0.1 10891.1 1.0X + SQL Json 11905 / 12120 0.1 11353.6 1.0X + SQL Parquet Vectorized 50 / 54 20.9 47.8 227.7X + SQL Parquet MR 195 / 199 5.4 185.8 58.6X + SQL ORC Vectorized 61 / 65 17.3 57.8 188.3X + SQL ORC Vectorized with copy 62 / 65 17.0 58.8 185.2X + SQL ORC MR 847 / 865 1.2 807.4 13.5X + + + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 21278 / 21404 0.0 20292.4 1.0X + SQL Json 22455 / 22625 0.0 21414.7 0.9X + SQL Parquet Vectorized 73 / 75 14.4 69.3 292.8X + SQL Parquet MR 220 / 226 4.8 209.7 96.8X + SQL ORC Vectorized 82 / 86 12.8 78.2 259.4X + SQL ORC Vectorized with copy 82 / 90 12.7 78.7 258.0X + SQL ORC MR 1568 / 1582 0.7 1495.4 13.6X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + repeatedStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + for (columnWidth <- List(10, 50, 100)) { + columnsBenchmark(1024 * 1024 * 1, columnWidth) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala deleted file mode 100644 index e43336d947364..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.parquet - -import java.io.File - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - -/** - * Benchmark to measure parquet read performance. - * To run this: - * spark-submit --class --jars - */ -object ParquetReadBenchmark { - val conf = new SparkConf() - conf.set("spark.sql.parquet.compression.codec", "snappy") - - val spark = SparkSession.builder - .master("local[1]") - .appName("test-sql-context") - .config(conf) - .getOrCreate() - - // Set default configs. Individual cases will change them if necessary. - spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - def intScanBenchmark(values: Int): Unit = { - // Benchmarks running through spark sql. - val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) - // Benchmarks driving reader component directly. - val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) - - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as id from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(id) from tempTable").collect() - } - - sqlBenchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from tempTable").collect() - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - // Driving the parquet reader in batch mode directly. - parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - val col = batch.column(0) - while (reader.nextBatch()) { - val numRows = batch.numRows() - var i = 0 - while (i < numRows) { - if (!col.isNullAt(i)) sum += col.getInt(i) - i += 1 - } - } - } finally { - reader.close() - } - } - } - - // Decoding in vectorized but having the reader return rows. - parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val it = batch.rowIterator() - while (it.hasNext) { - val record = it.next() - if (!record.isNullAt(0)) sum += record.getInt(0) - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X - SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X - */ - sqlBenchmark.run() - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X - ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X - */ - parquetReaderBenchmark.run() - } - } - } - - def intStringScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Int and String Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X - SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X - */ - benchmark.run() - } - } - } - - def stringDictionaryScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String Dictionary", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c1)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c1)) from tempTable").collect - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X - SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X - */ - benchmark.run() - } - } - } - - def partitionTableScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select id % 2 as p, cast(id as INT) as id from t1") - .write.partitionBy("p").parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Partitioned Table", values) - - benchmark.addCase("Read data column") { iter => - spark.sql("select sum(id) from tempTable").collect - } - - benchmark.addCase("Read partition column") { iter => - spark.sql("select sum(p) from tempTable").collect - } - - benchmark.addCase("Read both columns") { iter => - spark.sql("select sum(p), sum(id) from tempTable").collect - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Read data column 191 / 250 82.1 12.2 1.0X - Read partition column 82 / 86 192.4 5.2 2.3X - Read both columns 220 / 248 71.5 14.0 0.9X - */ - benchmark.run() - } - } - } - - def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + - s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String with Nulls Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c2)) from tempTable where c1 is " + - "not NULL and c2 is not NULL").collect() - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("PR Vectorized") { num => - var sum = 0 - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - val row = rowIterator.next() - val value = row.getUTF8String(0) - if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X - PR Vectorized 833 / 846 12.6 79.4 1.5X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X - PR Vectorized 732 / 772 14.3 69.8 1.4X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X - PR Vectorized 190 / 200 55.1 18.2 1.7X - */ - - benchmark.run() - } - } - } - - def main(args: Array[String]): Unit = { - intScanBenchmark(1024 * 1024 * 15) - intStringScanBenchmark(1024 * 1024 * 10) - stringDictionaryScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) - } - } -} From b7a036b75b8a1d287ac014b85e90d555753064c9 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 23 May 2018 13:12:05 -0700 Subject: [PATCH 551/593] [SPARK-24294] Throw SparkException when OOM in BroadcastExchangeExec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When OutOfMemoryError thrown from BroadcastExchangeExec, scala.concurrent.Future will hit scala bug – https://github.com/scala/bug/issues/9554, and hang until future timeout: We could wrap the OOM inside SparkException to resolve this issue. ## How was this patch tested? Manually tested. Author: jinxing Closes #21342 from jinxing64/SPARK-24294. --- .../spark/util/SparkFatalException.scala | 27 +++++++++++++++++++ .../util/SparkUncaughtExceptionHandler.scala | 13 ++++++--- .../org/apache/spark/util/ThreadUtils.scala | 2 ++ .../exchange/BroadcastExchangeExec.scala | 13 ++++++--- 4 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/SparkFatalException.scala diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala new file mode 100644 index 0000000000000..1aa2009fa9b5b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we catch + * fatal throwable in {@link scala.concurrent.Future}'s body, and re-throw + * SparkFatalException, which wraps the fatal throwable inside. + * Note that SparkFatalException should only be thrown from a {@link scala.concurrent.Future}, + * which is run by using ThreadUtils.awaitResult. ThreadUtils.awaitResult will catch + * it and re-throw the original exception/error. + */ +private[spark] final class SparkFatalException(val throwable: Throwable) extends Exception diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index e0f5af5250e7f..1b34fbde38cd6 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -39,10 +39,15 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) if (!ShutdownHookManager.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(SparkExitCode.OOM) - } else if (exitOnUncaughtException) { - System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + exception match { + case _: OutOfMemoryError => + System.exit(SparkExitCode.OOM) + case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => + // SPARK-24294: This is defensive code, in case that SparkFatalException is + // misused and uncaught. + System.exit(SparkExitCode.OOM) + case _ if exitOnUncaughtException => + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 81aaf79db0c13..165a15c73e7ca 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -200,6 +200,8 @@ private[spark] object ThreadUtils { val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] awaitable.result(atMost)(awaitPermission) } catch { + case e: SparkFatalException => + throw e.throwable // TimeoutException is thrown in the current thread, so not need to warp the exception. case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9e0ec9481b0de..c55f9b8f1a7fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkException} import org.apache.spark.launcher.SparkLauncher @@ -30,7 +31,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SparkFatalException, ThreadUtils} /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of @@ -111,12 +112,18 @@ case class BroadcastExchangeExec( SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + + throw new SparkFatalException( + new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + s"all worker nodes. As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " + s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value") - .initCause(oe.getCause) + .initCause(oe.getCause)) + case e if !NonFatal(e) => + throw new SparkFatalException(e) } } }(BroadcastExchangeExec.executionContext) From f4579332931c9bf424d0b6147fad89bd63da26f6 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 23 May 2018 17:21:29 -0700 Subject: [PATCH 552/593] [SPARK-23416][SS] Add a specific stop method for ContinuousExecution. ## What changes were proposed in this pull request? Add a specific stop method for ContinuousExecution. The previous StreamExecution.stop() method had a race condition as applied to continuous processing: if the cancellation was round-tripped to the driver too quickly, the generic SparkException it caused would be reported as the query death cause. We earlier decided that SparkException should not be added to the StreamExecution.isInterruptionException() whitelist, so we need to ensure this never happens instead. ## How was this patch tested? Existing tests. I could consistently reproduce the previous flakiness by putting Thread.sleep(1000) between the first job cancellation and thread interruption in StreamExecution.stop(). Author: Jose Torres Closes #21384 from jose-torres/fixKafka. --- .../streaming/MicroBatchExecution.scala | 18 ++++++++++++++++++ .../execution/streaming/StreamExecution.scala | 18 ------------------ .../continuous/ContinuousExecution.scala | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6709e7052f005..7817360810bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -126,6 +126,24 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + queryExecutionThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + logInfo(s"Query $prettyIdString was stopped") + } + /** * Repeatedly attempts to run batches as data arrives. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3fc8c7887896a..290de873c5cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -378,24 +378,6 @@ abstract class StreamExecution( } } - /** - * Signals to the thread executing micro-batches that it should stop running after the next - * batch. This method blocks until the thread stops running. - */ - override def stop(): Unit = { - // Set the state to TERMINATED so that the batching thread knows that it was interrupted - // intentionally - state.set(TERMINATED) - if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) - queryExecutionThread.interrupt() - queryExecutionThread.join() - // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak - sparkSession.sparkContext.cancelJobGroup(runId.toString) - } - logInfo(s"Query $prettyIdString was stopped") - } - /** * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 0e7d1019b9c8f..d16b24c89ebef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -356,6 +356,22 @@ class ContinuousExecution( } } } + + /** + * Stops the query execution thread to terminate the query. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + // The query execution thread will clean itself up in the finally clause of runContinuous. + // We just need to interrupt the long running job. + queryExecutionThread.interrupt() + queryExecutionThread.join() + } + logInfo(s"Query $prettyIdString was stopped") + } } object ContinuousExecution { From 230f1441978e14062d3e0b6ba1524acf36e3eafe Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Wed, 23 May 2018 17:22:52 -0700 Subject: [PATCH 553/593] [SPARK-24350][SQL] Fixes ClassCastException in the "array_position" function ## What changes were proposed in this pull request? ### Fixes `ClassCastException` in the `array_position` function - [SPARK-24350](https://issues.apache.org/jira/browse/SPARK-24350) When calling `array_position` function with a wrong type of the 1st argument an `AnalysisException` should be thrown instead of `ClassCastException` Example: ```sql select array_position('foo', 'bar') ``` ``` java.lang.ClassCastException: org.apache.spark.sql.types.StringType$ cannot be cast to org.apache.spark.sql.types.ArrayType at org.apache.spark.sql.catalyst.expressions.ArrayPosition.inputTypes(collectionOperations.scala:1398) at org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes$class.checkInputDataTypes(ExpectsInputTypes.scala:44) at org.apache.spark.sql.catalyst.expressions.ArrayPosition.checkInputDataTypes(collectionOperations.scala:1401) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:168) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:168) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:256) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:252) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:288) ``` ## How was this patch tested? unit test Author: Vayda, Oleksandr: IT (PRG) Closes #21401 from wajda/SPARK-24350-array_position-error-fix. --- .../catalyst/expressions/collectionOperations.scala | 10 ++++++++-- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 03b3b21a16617..8a877b02c8191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1395,8 +1395,14 @@ case class ArrayPosition(left: Expression, right: Expression) TypeUtils.getInterpretedOrdering(right.dataType) override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = - Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ec2a569f900d1..79e743d961af8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -708,6 +708,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } test("element_at function") { From 888340151f737bb68d0e419b1e949f11469881f9 Mon Sep 17 00:00:00 2001 From: sychen Date: Thu, 24 May 2018 11:02:09 +0800 Subject: [PATCH 554/593] [SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2`, while the new record may be larger than `oldSize * 2`. Then we may have a malformed UnsafeRow when querying this map, whose actual data is smaller than its declared size, and the data is corrupted. Author: sychen Closes #21311 from cxzl25/fix_LongToUnsafeRowMap_page_size. --- .../sql/execution/joins/HashedRelation.scala | 38 +++++++++++-------- .../execution/joins/HashedRelationSuite.scala | 26 ++++++++++++- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1465346eb802d..20ce01f4ce8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def append(key: Long, row: UnsafeRow): Unit = { val sizeInBytes = row.getSizeInBytes if (sizeInBytes >= (1 << SIZE_BITS)) { - sys.error("Does not support row that is larger than 256M") + throw new UnsupportedOperationException("Does not support row that is larger than 256M") } if (key < minKey) { @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = key } - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) { - val used = page.length - if (used >= (1 << 30)) { - sys.error("Can not build a HashedRelation that is larger than 8G") - } - ensureAcquireMemory(used * 8L * 2) - val newPage = new Array[Long](used * 2) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - page = newPage - freeMemory(used * 8L) - } + grow(row.getSizeInBytes) // copy the bytes of UnsafeRow val offset = cursor @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap growArray() } else if (numKeys > array.length / 2 * 0.75) { // The fill ratio should be less than 0.75 - sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys") + throw new UnsupportedOperationException( + "Cannot build HashedRelation with more than 1/3 billions unique keys") } } } else { @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } + private def grow(inputRowSize: Int): Unit = { + // There is 8 bytes for the pointer to next value + val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.length) { + if (neededNumWords > (1 << 30)) { + throw new UnsupportedOperationException( + "Can not build a HashedRelation that is larger than 8G") + } + val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) + ensureAcquireMemory(newNumWords * 8L) + val newPage = new Array[Long](newNumWords.toInt) + Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, + cursor - Platform.LONG_ARRAY_OFFSET) + val used = page.length + page = newPage + freeMemory(used * 8L) + } + } + private def growArray(): Unit = { var old_array = array val n = array.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 51f8c3325fdff..037cc2e3ccad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24257: insert big values into LongToUnsafeRowMap") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Array[DataType](StringType)) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + + val key = 0L + // the page array is initialized with length 1 << 17 (1M bytes), + // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug + val bigStr = UTF8String.fromString("x" * (1 << 19)) + + map.append(key, unsafeProj(InternalRow(bigStr))) + map.optimize() + + val resultRow = new UnsafeRow(1) + assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr) + map.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() From 486ecc680e9a0e7b6b3c3a45fb883a61072096fc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 24 May 2018 11:34:13 +0800 Subject: [PATCH 555/593] [SPARK-24322][BUILD] Upgrade Apache ORC to 1.4.4 ## What changes were proposed in this pull request? ORC 1.4.4 includes [nine fixes](https://issues.apache.org/jira/issues/?filter=12342568&jql=project%20%3D%20ORC%20AND%20resolution%20%3D%20Fixed%20AND%20fixVersion%20%3D%201.4.4). One of the issues is about `Timestamp` bug (ORC-306) which occurs when `native` ORC vectorized reader reads ORC column vector's sub-vector `times` and `nanos`. ORC-306 fixes this according to the [original definition](https://github.com/apache/hive/blob/master/storage-api/src/java/org/apache/hadoop/hive/ql/exec/vector/TimestampColumnVector.java#L45-L46) and this PR includes the updated interpretation on ORC column vectors. Note that `hive` ORC reader and ORC MR reader is not affected. ```scala scala> spark.version res0: String = 2.3.0 scala> spark.sql("set spark.sql.orc.impl=native") scala> Seq(java.sql.Timestamp.valueOf("1900-05-05 12:34:56.000789")).toDF().write.orc("/tmp/orc") scala> spark.read.orc("/tmp/orc").show(false) +--------------------------+ |value | +--------------------------+ |1900-05-05 12:34:55.000789| +--------------------------+ ``` This PR aims to update Apache Spark to use it. **FULL LIST** ID | TITLE -- | -- ORC-281 | Fix compiler warnings from clang 5.0 ORC-301 | `extractFileTail` should open a file in `try` statement ORC-304 | Fix TestRecordReaderImpl to not fail with new storage-api ORC-306 | Fix incorrect workaround for bug in java.sql.Timestamp ORC-324 | Add support for ARM and PPC arch ORC-330 | Remove unnecessary Hive artifacts from root pom ORC-332 | Add syntax version to orc_proto.proto ORC-336 | Remove avro and parquet dependency management entries ORC-360 | Implement error checking on subtype fields in Java ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21372 from dongjoon-hyun/SPARK_ORC144. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- .../sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../datasources/orc/OrcColumnarBatchReader.java | 2 +- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ 7 files changed, 18 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e710e26348117..723180a14febb 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 97ad17a9ff7b1..ea08a001a1c9b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e21bfef8c4291..da874026d7d10 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -176,8 +176,8 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 6e37e518d86e4..883c096ae1ae9 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.10.0 - 1.4.3 + 1.4.4 nohive 1.6.0 9.3.20.v20170531 diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..9bfad1e83ee7b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -136,7 +136,7 @@ public int getInt(int rowId) { public long getLong(int rowId) { int index = getRowIndex(rowId); if (isTimestamp) { - return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; } else { return longData.vector[index]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index dcebdc39f0aa2..a0d9578a377b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -497,7 +497,7 @@ private void putValues( * Returns the number of micros since epoch from an element of TimestampColumnVector. */ private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { - return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000); } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8a3bbd03a26dc..02bfb7197ffc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.sql.Timestamp import java.util.Locale import org.apache.orc.OrcConf.COMPRESS @@ -169,6 +170,14 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-24322 Fix incorrect workaround for bug in java.sql.Timestamp") { + withTempPath { path => + val ts = Timestamp.valueOf("1900-05-05 12:34:56.000789") + Seq(ts).toDF.write.orc(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { From e108f84f5cd562f070872651bdcf6c02e80dd585 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 24 May 2018 11:42:25 +0800 Subject: [PATCH 556/593] [MINOR][CORE] Cleanup unused vals in `DAGScheduler.handleTaskCompletion` ## What changes were proposed in this pull request? Cleanup unused vals in `DAGScheduler.handleTaskCompletion` to reduce the code complexity slightly. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #21406 from jiangxb1987/handleTaskCompletion. --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ea7bfd7d7a68d..041eade82d3ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1167,9 +1167,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task - val taskId = event.taskInfo.id val stageId = task.stageId - val taskType = Utils.getFormattedClassName(task) outputCommitCoordinator.taskCompleted( stageId, @@ -1323,7 +1321,7 @@ class DAGScheduler( "tasks in ShuffleMapStages.") } - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => + case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1411,7 +1409,7 @@ class DAGScheduler( } } - case commitDenied: TaskCommitDenied => + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case _: ExceptionFailure | _: TaskKilled => From 8a545822d0cc3a866ef91a94e58ea5c8b1014007 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 May 2018 13:21:02 +0800 Subject: [PATCH 557/593] [SPARK-24364][SS] Prevent InMemoryFileIndex from failing if file path doesn't exist ## What changes were proposed in this pull request? This PR proposes to follow up https://github.com/apache/spark/pull/15153 and complete SPARK-17599. `FileSystem` operation (`fs.getFileBlockLocations`) can still fail if the file path does not exist. For example see the exception message below: ``` Error occurred while processing: File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... java.io.FileNotFoundException: File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... org.apache.hadoop.hdfs.DistributedFileSystem.getFileBlockLocations(DistributedFileSystem.java:249) at org.apache.hadoop.hdfs.DistributedFileSystem.getFileBlockLocations(DistributedFileSystem.java:229) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles$3.apply(InMemoryFileIndex.scala:314) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles$3.apply(InMemoryFileIndex.scala:297) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles(InMemoryFileIndex.scala:297) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles$1.apply(InMemoryFileIndex.scala:174) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles$1.apply(InMemoryFileIndex.scala:173) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles(InMemoryFileIndex.scala:173) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.listLeafFiles(InMemoryFileIndex.scala:126) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.refresh0(InMemoryFileIndex.scala:91) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.(InMemoryFileIndex.scala:67) at org.apache.spark.sql.execution.datasources.DataSource.tempFileIndex$lzycompute$1(DataSource.scala:161) at org.apache.spark.sql.execution.datasources.DataSource.org$apache$spark$sql$execution$datasources$DataSource$$tempFileIndex$1(DataSource.scala:152) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:166) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:261) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:94) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:94) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:33) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:196) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:206) at com.hwx.StreamTest$.main(StreamTest.scala:97) at com.hwx.StreamTest.main(StreamTest.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52) at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:906) at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197) at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: org.apache.hadoop.ipc.RemoteException(java.io.FileNotFoundException): File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... ``` So, it fixes it to make a warning instead. ## How was this patch tested? It's hard to write a test. Manually tested multiple times. Author: hyukjinkwon Closes #21408 from HyukjinKwon/missing-files. --- .../datasources/InMemoryFileIndex.scala | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 739d1f456e3ec..9d9f8bd5bb58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -294,9 +294,12 @@ object InMemoryFileIndex extends Logging { if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles } - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + val missingFiles = mutable.ArrayBuffer.empty[String] + val filteredLeafStatuses = allLeafStatuses.filterNot( + status => shouldFilterOut(status.getPath.getName)) + val resolvedLeafStatuses = filteredLeafStatuses.flatMap { case f: LocatedFileStatus => - f + Some(f) // NOTE: // @@ -311,14 +314,27 @@ object InMemoryFileIndex extends Logging { // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), // which is very slow on some file system (RawLocalFileSystem, which is launch a // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) + try { + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + Some(lfs) + } catch { + case _: FileNotFoundException => + missingFiles += f.getPath.toString + None } - lfs } + + if (missingFiles.nonEmpty) { + logWarning( + s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") + } + + resolvedLeafStatuses } /** Checks if we should filter out this path name. */ From 4a14dc0aff9cac85390cab94bc183271fa95beef Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 May 2018 14:19:32 +0800 Subject: [PATCH 558/593] [SPARK-22269][BUILD] Run Java linter via SBT for Jenkins ## What changes were proposed in this pull request? This PR proposes to check Java lint via SBT for Jenkins. It uses the SBT wrapper for checkstyle. I manually tested. If we build the codes once, running this script takes 2 mins at maximum in my local: Test codes: ``` Checkstyle failed at following occurrences: [error] Checkstyle error found in /.../spark/core/src/test/java/test/org/apache/spark/JavaAPISuite.java:82: Line is longer than 100 characters (found 103). [error] 1 issue(s) found in Checkstyle report: /.../spark/core/target/checkstyle-test-report.xml [error] Checkstyle error found in /.../spark/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java:84: Line is longer than 100 characters (found 115). [error] 1 issue(s) found in Checkstyle report: /.../spark/sql/hive/target/checkstyle-test-report.xml ... ``` Main codes: ``` Checkstyle failed at following occurrences: [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java:39: Line is longer than 100 characters (found 104). [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:26: Line is longer than 100 characters (found 110). [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:30: Line is longer than 100 characters (found 104). ... ``` ## How was this patch tested? Manually tested. Jenkins build should test this. Author: hyukjinkwon Closes #21399 from HyukjinKwon/SPARK-22269. --- dev/run-tests.py | 5 ++--- dev/sbt-checkstyle | 42 ++++++++++++++++++++++++++++++++++++++++ project/SparkBuild.scala | 14 +++++++++++++- project/plugins.sbt | 8 ++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100755 dev/sbt-checkstyle diff --git a/dev/run-tests.py b/dev/run-tests.py index 164c1e2200aa9..5e8c8590b5c34 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -204,7 +204,7 @@ def run_scala_style_checks(): def run_java_style_checks(): set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") - run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "sbt-checkstyle")]) def run_python_style_checks(): @@ -574,8 +574,7 @@ def main(): or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - # run_java_style_checks() - pass + run_java_style_checks() if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle new file mode 100755 index 0000000000000..8821a7c0e4ccf --- /dev/null +++ b/dev/sbt-checkstyle @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pmesos \ + -Pkafka-0-8 \ + -Pkubernetes \ + -Pyarn \ + -Pflume \ + -Phive \ + -Phive-thriftserver \ + checkstyle test:checkstyle \ + | awk '{if($1~/error/)print}' \ +) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7469f11df0294..4cb6495a33b61 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -27,6 +27,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._ import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -317,7 +318,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -740,6 +741,17 @@ object Unidoc { ) } +object Checkstyle { + lazy val settings = Seq( + checkstyleSeverityLevel := Some(CheckstyleSeverityLevel.Error), + javaSource in (Compile, checkstyle) := baseDirectory.value / "src/main/java", + javaSource in (Test, checkstyle) := baseDirectory.value / "src/test/java", + checkstyleConfigLocation := CheckstyleConfigLocation.File("dev/checkstyle.xml"), + checkstyleOutputFile := baseDirectory.value / "target/checkstyle-output.xml", + checkstyleOutputFile in Test := baseDirectory.value / "target/checkstyle-output.xml" + ) +} + object CopyDependencies { val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") diff --git a/project/plugins.sbt b/project/plugins.sbt index 96bdb9067ae59..ffbd417b0f145 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,11 @@ +addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") + +// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" + +// checkstyle uses guava 23.0. +libraryDependencies += "com.google.guava" % "guava" % "23.0" + // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") From 3469f5c989e686866051382a3a28b2265619cab9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 24 May 2018 20:55:26 +0800 Subject: [PATCH 559/593] [SPARK-24230][SQL] Fix SpecificParquetRecordReaderBase with dictionary filters. ## What changes were proposed in this pull request? I missed this commit when preparing #21070. When Parquet is able to filter blocks with dictionary filtering, the expected total value count to be too high in Spark, leading to an error when there were fewer than expected row groups to process. Spark should get the row groups from Parquet to pick up new filter schemes in Parquet like dictionary filtering. ## How was this patch tested? Using in production at Netflix. Added test case for dictionary-filtered blocks. Author: Ryan Blue Closes #21295 from rdblue/SPARK-24230-fix-parquet-block-tracking. --- .../parquet/SpecificParquetRecordReaderBase.java | 6 ++++-- .../datasources/parquet/ParquetQuerySuite.scala | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index daedfd7e78f5f..c975e52734e01 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -146,7 +146,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.reader = new ParquetFileReader( configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } @@ -224,7 +225,8 @@ protected void initialize(String path, List columns) throws IOException this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e1f094d0a7af3..2b1227faf48a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -879,6 +879,18 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-24230: filter row group using dictionary") { + withSQLConf(("parquet.filter.dictionary.enabled", "true")) { + // create a table with values from 0, 2, ..., 18 that will be dictionary-encoded + withParquetTable((0 until 100).map(i => ((i * 2) % 20, s"data-$i")), "t") { + // search for a key that is not present so the dictionary filter eliminates all row groups + // Fails without SPARK-24230: + // java.io.IOException: expecting more rows but reached last block. Read 0 out of 50 + checkAnswer(sql("SELECT _2 FROM t WHERE t._1 = 5"), Seq.empty) + } + } + } } object TestingUDT { From 13bedc05c28fcc6e739fb472bd2ee3035fa11648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 24 May 2018 22:18:58 +0800 Subject: [PATCH 560/593] [SPARK-24329][SQL] Test for skipping multi-space lines ## What changes were proposed in this pull request? The PR is a continue of https://github.com/apache/spark/pull/21380 . It checks cases that are handled by the code: https://github.com/apache/spark/blob/e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala#L303-L304 Basically the code skips lines with one or many whitespaces, and lines with comments (see [filterCommentAndEmpty](https://github.com/apache/spark/blob/e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala#L47)) ```scala iter.filter { line => line.trim.nonEmpty && !line.startsWith(options.comment.toString) } ``` Closes #21380 ## How was this patch tested? Added a test for the case described above. Author: Maxim Gekk Author: Maxim Gekk Closes #21394 from MaxGekk/test-for-multi-space-lines. --- .../resources/test-data/comments-whitespaces.csv | 8 ++++++++ .../sql/execution/datasources/csv/CSVSuite.scala | 15 +++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/comments-whitespaces.csv diff --git a/sql/core/src/test/resources/test-data/comments-whitespaces.csv b/sql/core/src/test/resources/test-data/comments-whitespaces.csv new file mode 100644 index 0000000000000..2737978f83a5e --- /dev/null +++ b/sql/core/src/test/resources/test-data/comments-whitespaces.csv @@ -0,0 +1,8 @@ +# The file contains comments, whitespaces and empty lines +colA +# empty line + +# the line with a few whitespaces + +# int value with leading and trailing whitespaces + "a" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07e6c74b14d0d..2bac1a3e2d4c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1368,4 +1368,19 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } + + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { + val schema = new StructType().add("colA", StringType) + val ds = spark + .read + .schema(schema) + .option("multiLine", false) + .option("header", true) + .option("comment", "#") + .option("ignoreLeadingWhiteSpace", false) + .option("ignoreTrailingWhiteSpace", false) + .csv(testFile("test-data/comments-whitespaces.csv")) + + checkAnswer(ds, Seq(Row(""" "a" """))) + } } From 0d89943449764e8a578edf4ceb6245158421eb96 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 24 May 2018 23:38:50 +0800 Subject: [PATCH 561/593] [SPARK-24378][SQL] Fix date_trunc function incorrect examples ## What changes were proposed in this pull request? Fix `date_trunc` function incorrect examples. ## How was this patch tested? N/A Author: Yuming Wang Closes #21423 from wangyum/SPARK-24378. --- .../expressions/datetimeExpressions.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index e8d85f72f7a7a..08838d2b2c612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1533,14 +1533,14 @@ case class TruncDate(date: Expression, format: Expression) """, examples = """ Examples: - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); - 2015-01-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); - 2015-03-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); - 2015-03-05T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); - 2015-03-05T09:00:00 + > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359'); + 2015-01-01 00:00:00 + > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359'); + 2015-03-01 00:00:00 + > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359'); + 2015-03-05 00:00:00 + > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359'); + 2015-03-05 09:00:00 """, since = "2.3.0") // scalastyle:on line.size.limit From 53c06ddabbdf689f8823807445849ad63173676f Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 24 May 2018 13:00:24 -0700 Subject: [PATCH 562/593] [SPARK-24332][SS][MESOS] Fix places reading 'spark.network.timeout' as milliseconds ## What changes were proposed in this pull request? This PR replaces `getTimeAsMs` with `getTimeAsSeconds` to fix the issue that reading "spark.network.timeout" using a wrong time unit when the user doesn't specify a time out. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #21382 from zsxwing/fix-network-timeout-conf. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaRelation.scala | 4 +++- .../scala/org/apache/spark/sql/kafka010/KafkaSource.scala | 2 +- .../scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala | 2 +- .../cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala | 2 +- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 64ba98762788c..737da2e51b125 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -68,7 +68,7 @@ private[kafka010] class KafkaMicroBatchReader( private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", - SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 7103709969c18..c31e6ed3e0903 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -48,7 +48,9 @@ private[kafka010] class KafkaRelation( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sqlContext.sparkContext.conf.getTimeAsSeconds( + "spark.network.timeout", + "120s") * 1000L).toString ).toLong override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1c7b3a29a861f..101e649727fcf 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaSource( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L).toString ).toLong private val maxOffsetsPerTrigger = diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 81abc9860bfc3..3efc90fe466b2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -65,7 +65,7 @@ private[spark] class KafkaRDD[K, V]( // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", - conf.getTimeAsMs("spark.network.timeout", "120s")) + conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 9b75e4c98344a..d35bea4aca311 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -634,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } From 0fd68cb7278e5fdf106e73b580ee7dd829006386 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:08:52 -0700 Subject: [PATCH 563/593] [SPARK-24234][SS] Support multiple row writers in continuous processing shuffle reader. ## What changes were proposed in this pull request? https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii Support multiple different row writers in continuous processing shuffle reader. Note that having multiple read-side buffers ended up being the natural way to do this. Otherwise it's hard to express the constraint of sending an epoch marker only when all writers have sent one. ## How was this patch tested? new unit tests Author: Jose Torres Closes #21385 from jose-torres/multipleWrite. --- .../shuffle/ContinuousShuffleReadRDD.scala | 21 ++- .../shuffle/UnsafeRowReceiver.scala | 87 ++++++++-- .../shuffle/ContinuousShuffleReadSuite.scala | 163 +++++++++++++++--- 3 files changed, 227 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 270b1a5c28dee..801b28b751bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -25,11 +25,16 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { +case class ContinuousShuffleReadPartition( + index: Int, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -42,16 +47,24 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. + * + * @param sc the RDD context + * @param numPartitions the number of read partitions for this RDD + * @param queueSize the size of the row buffers to use + * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD + * @param epochIntervalMs the checkpoint interval of the streaming query */ class ContinuousShuffleReadRDD( sc: SparkContext, numPartitions: Int, - queueSize: Int = 1024) + queueSize: Int = 1024, + numShuffleWriters: Int = 1, + epochIntervalMs: Long = 1000) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize) + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index b8adbb743c6c2..d81f552d56626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -27,10 +29,17 @@ import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * + * Each message comes tagged with writerId, identifying which writer the message is coming + * from. The receiver will only begin the next epoch once all writers have sent an epoch + * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable -private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -41,11 +50,15 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa */ private[shuffle] class UnsafeRowReceiver( queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. - private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + } // Exposed for testing to determine if the endpoint gets stopped on task end. private[shuffle] val stopped = new AtomicBoolean(false) @@ -56,20 +69,70 @@ private[shuffle] class UnsafeRowReceiver( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: UnsafeRowReceiverMessage => - queue.put(r) + queues(r.writerId).put(r) context.reply(()) } override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = queue.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null + // An array of flags for whether each writer ID has gotten an epoch marker. + private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + + private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { + override def call(): UnsafeRowReceiverMessage = queues(writerId).take() } - override def close(): Unit = {} + // Initialize by submitting tasks to read the first row from each writer. + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + /** + * In each call to getNext(), we pull the next row available in the completion queue, and then + * submit another task to read the next row from the writer which returned it. + * + * When a writer sends an epoch marker, we note that it's finished and don't submit another + * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. + */ + override def getNext(): UnsafeRow = { + var nextRow: UnsafeRow = null + while (!finished && nextRow == null) { + completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { + case null => + // Try again if the poll didn't wait long enough to get a real result. + // But we should be getting at least an epoch marker every checkpoint interval. + val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { + case (flag, idx) if !flag => idx + } + logWarning( + s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + + s"for writers $writerIdsUncommitted to send epoch markers.") + + // The completion service guarantees this future will be available immediately. + case future => future.get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + nextRow = r + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need to loop again to poll from the remaining + // writers. + writerEpochMarkersReceived(writerId) = true + if (writerEpochMarkersReceived.forall(_ == true)) { + finished = true + } + } + } + } + + nextRow + } + + override def close(): Unit = { + executor.shutdownNow() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index b25e75b3b37a6..2e4d607a403ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -21,7 +21,8 @@ import org.apache.spark.{TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String class ContinuousShuffleReadSuite extends StreamTest { @@ -30,6 +31,11 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { messages.foreach(endpoint.askSync[Unit](_)) } @@ -57,8 +63,8 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(111)) + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) ) ctx.markTaskCompleted(None) @@ -71,8 +77,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with marker last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -86,10 +95,10 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverRow(unsafeRow(111)), - ReceiverRow(unsafeRow(222)), - ReceiverRow(unsafeRow(333)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(111)), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) ) val iter = rdd.compute(rdd.partitions(0), ctx) @@ -101,11 +110,11 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverRow(unsafeRow(111)), - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(222)), - ReceiverRow(unsafeRow(333)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) ) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) @@ -118,14 +127,15 @@ class ContinuousShuffleReadSuite extends StreamTest { test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( endpoint, - ReceiverEpochMarker(), - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(111)), - ReceiverEpochMarker(), - ReceiverEpochMarker(), - ReceiverEpochMarker() + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0) ) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -146,8 +156,8 @@ class ContinuousShuffleReadSuite extends StreamTest { // Send index for identification. send( part.endpoint, - ReceiverRow(unsafeRow(part.index)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(part.index)), + ReceiverEpochMarker(0) ) } @@ -160,25 +170,122 @@ class ContinuousShuffleReadSuite extends StreamTest { } test("blocks waiting for new rows") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) + val epoch = rdd.compute(rdd.partitions(0), ctx) val readRowThread = new Thread { override def run(): Unit = { - // set the non-inheritable thread local - TaskContext.setTaskContext(ctx) - val epoch = rdd.compute(rdd.partitions(0), ctx) - epoch.next().getInt(0) + try { + epoch.next().getInt(0) + } catch { + case _: InterruptedException => // do nothing - expected at test ending + } } } try { readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.TIMED_WAITING) } } finally { readRowThread.interrupt() readRowThread.join() } } + + test("multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + + test("epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(2) + ) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + send(endpoint, ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarkerThread.isAlive + } + + // Join to pick up assertion failures. + readEpochMarkerThread.join() + } + + test("writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } } From 3b20b34ab72c92d9d20188ed430955e1a94eac9c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 25 May 2018 11:16:35 +0800 Subject: [PATCH 564/593] [SPARK-24367][SQL] Parquet: use JOB_SUMMARY_LEVEL instead of deprecated flag ENABLE_JOB_SUMMARY ## What changes were proposed in this pull request? In current parquet version,the conf ENABLE_JOB_SUMMARY is deprecated. When writing to Parquet files, the warning message ```WARN org.apache.parquet.hadoop.ParquetOutputFormat: Setting parquet.enable.summary-metadata is deprecated, please use parquet.summary.metadata.level``` keeps showing up. From https://github.com/apache/parquet-mr/blame/master/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetOutputFormat.java#L164 we can see that we should use JOB_SUMMARY_LEVEL. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21411 from gengliangwang/summaryLevel. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../datasources/parquet/ParquetFileFormat.scala | 10 ++++++---- .../datasources/parquet/ParquetCommitterSuite.scala | 7 ++++++- .../execution/datasources/parquet/ParquetIOSuite.scala | 2 +- .../parquet/ParquetPartitionDiscoverySuite.scala | 2 +- .../datasources/parquet/ParquetQuerySuite.scala | 2 +- .../sql/sources/ParquetHadoopFsRelationSuite.scala | 2 +- 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 15ba10f604510..93d356fef07af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -395,7 +395,7 @@ object SQLConf { .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + - "will never be created, irrespective of the value of parquet.enable.summary-metadata") + "will never be created, irrespective of the value of parquet.summary.metadata.level") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d1f9e11ed4225..60fc9ec7e1f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -34,6 +34,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType @@ -125,16 +126,17 @@ class ParquetFileFormat conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) // SPARK-15719: Disables writing Parquet summary files by default. - if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { - conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) } - if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (ParquetOutputFormat.getJobSummaryLevel(conf) == JobSummaryLevel.NONE && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { // output summary is requested, but the class is not a Parquet Committer logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + s" create job summaries. " + - s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") } new OutputWriterFactory { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index f3ecc5ced689f..4b2437803d645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -91,9 +91,14 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils summary: Boolean, check: Boolean): Option[FileStatus] = { var result: Option[FileStatus] = None + val summaryLevel = if (summary) { + "ALL" + } else { + "NONE" + } withSQLConf( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> summaryLevel) { withTempPath { dest => val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) val destPath = new Path(dest.toURI) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0b3e8ca060d87..002c42f23bd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -543,7 +543,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) - withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withSQLConf(ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" spark.range(1 << 16).selectExpr("(id % 4) AS i") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index e887c9734a8b8..9966ed94a8392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1014,7 +1014,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val path = dir.getCanonicalPath withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", "spark.sql.sources.commitProtocolClass" -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { spark.range(3).write.parquet(s"$path/p0=0/p1=0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2b1227faf48a0..dbf637783e6d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName, SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true", - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true" + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL" ) { testSchemaMerging(2) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index dce5bb7ddba66..6858bbc441721 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -124,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { withTempPath { dir => From 64fad0b519cf35b8c0a0dec18dd3df9488a5ed25 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 24 May 2018 21:38:04 -0700 Subject: [PATCH 565/593] [SPARK-24244][SPARK-24368][SQL] Passing only required columns to the CSV parser ## What changes were proposed in this pull request? uniVocity parser allows to specify only required column names or indexes for [parsing](https://www.univocity.com/pages/parsers-tutorial) like: ``` // Here we select only the columns by their indexes. // The parser just skips the values in other columns parserSettings.selectIndexes(4, 0, 1); CsvParser parser = new CsvParser(parserSettings); ``` In this PR, I propose to extract indexes from required schema and pass them into the CSV parser. Benchmarks show the following improvements in parsing of 1000 columns: ``` Select 100 columns out of 1000: x1.76 Select 1 column out of 1000: x2 ``` **Note**: Comparing to current implementation, the changes can return different result for malformed rows in the `DROPMALFORMED` and `FAILFAST` modes if only subset of all columns is requested. To have previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## How was this patch tested? It was tested by new test which selects 3 columns out of 15, by existing tests and by new benchmarks. Author: Maxim Gekk Author: Maxim Gekk Closes #21415 from MaxGekk/csv-column-pruning2. --- docs/sql-programming-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVFileFormat.scala | 18 ++++++-- .../datasources/csv/CSVOptions.scala | 3 ++ .../datasources/csv/UnivocityParser.scala | 26 ++++++----- .../datasources/csv/CSVBenchmarks.scala | 46 +++++++++++++++++++ .../datasources/csv/CSVInferSchemaSuite.scala | 22 ++++----- .../execution/datasources/csv/CSVSuite.scala | 41 ++++++++++++++--- .../csv/UnivocityParserSuite.scala | 37 +++++++-------- 10 files changed, 152 insertions(+), 52 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f1ed316341b95..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 93d356fef07af..8d2320d8a6ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1307,6 +1307,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** @@ -1664,6 +1671,8 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 917f0cb221412..ac4580a0919ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -480,6 +480,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def csv(csvDataset: Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone) val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index e20977a4ec79f..21279d6daf7ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -41,8 +41,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) val csvDataSource = CSVDataSource(parsedOptions) csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -51,8 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } @@ -64,7 +68,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { dataSchema: StructType): OutputWriterFactory = { CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) csvOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -97,6 +104,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new CSVOptions( options, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1066d156acd74..7119189a4e131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -28,16 +28,19 @@ import org.apache.spark.sql.catalyst.util._ class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], + val columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { def this( parameters: Map[String, String], + columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String = "") = { this( CaseInsensitiveMap(parameters), + columnPruning, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index d442ba7e59c61..1a3dacb8398e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,7 +74,53 @@ object CSVBenchmarks { } } + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 81091 / 81692 0.0 81090.7 1.0X + Select 100 columns 30003 / 34448 0.0 30003.0 2.7X + Select one column 24792 / 24855 0.0 24792.0 3.3X + count() 24344 / 24642 0.0 24343.8 3.3X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..842251be92c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(NullType, "", options) == NullType) assert(CSVInferSchema.inferField(NullType, null, options) == NullType) assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) @@ -41,7 +41,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) @@ -60,21 +60,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT") + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT") + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } test("Timestamp field types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) } test("Boolean fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) } @@ -92,12 +92,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT") + options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) @@ -111,12 +111,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), "GMT") + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == @@ -134,7 +134,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { test("DoubleType should be infered when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", - "positiveInf" -> "inf"), "GMT") + "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2bac1a3e2d4c6..afe10bdc4de26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1383,4 +1385,29 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(ds, Seq(Row(""" "a" """))) } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + assert(idf.count() == 2) + checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index efbf73534bd19..458edb253fb33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParserSuite extends SparkFunSuite { - private val parser = - new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String], "GMT")) + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) private def assertNull(v: Any) = assert(v == null) @@ -38,7 +39,7 @@ class UnivocityParserSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } @@ -51,21 +52,21 @@ class UnivocityParserSuite extends SparkFunSuite { // Nullable field with nullValue option. types.foreach { t => // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = false, options = options) var message = intercept[RuntimeException] { @@ -81,7 +82,7 @@ class UnivocityParserSuite extends SparkFunSuite { // If nullValue is different with empty string, then, empty string should not be casted into // null. Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") val converter = parser.makeConverter("_1", StringType, nullable = b, options = options) assert(converter.apply("") == UTF8String.fromString("")) @@ -89,7 +90,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val exception = intercept[RuntimeException]{ parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") } @@ -97,7 +98,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) @@ -107,7 +108,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = @@ -116,7 +117,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(castedTimestamp == expectedTime * 1000L) val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), "GMT") + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) @@ -131,7 +132,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val types = Seq(DoubleType, FloatType) val input = Seq("10u000", "abc", "1 2/3") types.foreach { dt => @@ -145,7 +146,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, options = options ).apply("nn").asInstanceOf[Float] @@ -156,7 +157,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, options = options ).apply("-").asInstanceOf[Double] @@ -165,14 +166,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Float] @@ -181,14 +182,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Double] From fd315f5884c03c6dd21abca178897584dee83f1a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 25 May 2018 12:49:06 -0700 Subject: [PATCH 566/593] [MINOR] Add port SSL config in toString and scaladoc ## What changes were proposed in this pull request? SPARK-17874 introduced a new configuration to set the port where SSL services bind to. We missed to update the scaladoc and the `toString` method, though. The PR adds it in the missing places ## How was this patch tested? checked the `toString` output in the logs Author: Marco Gaido Closes #21429 from mgaido91/minor_ssl. --- core/src/main/scala/org/apache/spark/SSLOptions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 477b01968c6ef..04c38f12acc78 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -128,7 +128,7 @@ private[spark] case class SSLOptions( } /** Returns a string representation of this SSLOptions with all the passwords masked. */ - override def toString: String = s"SSLOptions{enabled=$enabled, " + + override def toString: String = s"SSLOptions{enabled=$enabled, port=$port, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" @@ -142,6 +142,7 @@ private[spark] object SSLOptions extends Logging { * * The following settings are allowed: * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].port` - the port where to bind the SSL server * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory * $ - `[ns].keyStorePassword` - a password to the key-store file * $ - `[ns].keyPassword` - a password to the private key From 1b1528a504febfadf6fe41fd72e657689da50525 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 25 May 2018 15:42:46 -0700 Subject: [PATCH 567/593] [SPARK-24366][SQL] Improving of error messages for type converting ## What changes were proposed in this pull request? Currently, users are getting the following error messages on type conversions: ``` scala.MatchError: test (of class java.lang.String) ``` The message doesn't give any clues to the users where in the schema the error happened. In this PR, I would like to improve the error message like: ``` The value (test) of the type (java.lang.String) cannot be converted to struct ``` ## How was this patch tested? Added tests for converting of wrong values to `struct`, `map`, `array`, `string` and `decimal`. Author: Maxim Gekk Closes #21410 from MaxGekk/type-conv-error. --- .../sql/catalyst/CatalystTypeConverters.scala | 16 +++++++ .../CatalystTypeConvertersSuite.scala | 45 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 474ec592201d9..9e9105a157abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -170,6 +170,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -206,6 +209,10 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + "cannot be converted to a map type with " + + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})") } } @@ -252,6 +259,9 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${structType.catalogString}") } override def toScala(row: InternalRow): Row = { @@ -276,6 +286,9 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the string type") } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString @@ -309,6 +322,9 @@ object CatalystTypeConverters { case d: JavaBigDecimal => Decimal(d) case d: JavaBigInteger => Decimal(d) case d: Decimal => d + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${dataType.catalogString}") } decimal.toPrecision(dataType.precision, dataType.scale) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f3702ec92b425..f99af9b84d959 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -94,4 +94,49 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) == doubleGenericArray) } + + test("converting a wrong value to the struct type") { + val structType = new StructType().add("f1", IntegerType) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(structType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to struct")) + } + + test("converting a wrong value to the map type") { + val mapType = MapType(StringType, IntegerType, false) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(mapType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to a map type with key " + + "type (string) and value type (int)")) + } + + test("converting a wrong value to the array type") { + val arrayType = ArrayType(IntegerType, true) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(arrayType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to an array of int")) + } + + test("converting a wrong value to the decimal type") { + val decimalType = DecimalType(10, 0) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(decimalType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to decimal(10,0)")) + } + + test("converting a wrong value to the string type") { + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(StringType)(0.1) + } + assert(exception.getMessage.contains("The value (0.1) of the type " + + "(java.lang.Double) cannot be converted to the string type")) + } } From ed1a65448f228776afe2e5c6b1ac4228d2ed2854 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 26 May 2018 20:26:00 +0800 Subject: [PATCH 568/593] [SPARK-19112][CORE][FOLLOW-UP] Add missing shortCompressionCodecNames to configuration. ## What changes were proposed in this pull request? Spark provides four codecs: `lz4`, `lzf`, `snappy`, and `zstd`. This pr add missing shortCompressionCodecNames to configuration. ## How was this patch tested? manually tested Author: Yuming Wang Closes #21431 from wangyum/SPARK-19112. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index fd2670cba2125..64af0e98a82f5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -910,8 +910,8 @@ Apart from these, the following properties are also available, and may be useful lz4 The codec used to compress internal data such as RDD partitions, event log, broadcast variables - and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, - and snappy. You can also use fully qualified class names to specify the codec, + and shuffle outputs. By default, Spark provides four codecs: lz4, lzf, + snappy, and zstd. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, From d440699192f21b14dfb8ec0dc5673537e1003b55 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Sat, 26 May 2018 20:42:23 -0700 Subject: [PATCH 569/593] [SPARK-24381][TESTING] Add unit tests for NOT IN subquery around null values ## What changes were proposed in this pull request? This PR adds several unit tests along the `cols NOT IN (subquery)` pathway. There are a scattering of tests here and there which cover this codepath, but there doesn't seem to be a unified unit test of the correctness of null-aware anti joins anywhere. I have also added a brief explanation of how this expression behaves in SubquerySuite. Lastly, I made some clarifying changes in the NOT IN pathway in RewritePredicateSubquery. ## How was this patch tested? Added unit tests! There should be no behavioral change in this PR. Author: Miles Yucht Closes #21425 from mgyucht/spark-24381. --- .../sql/catalyst/optimizer/subquery.scala | 9 +- ...not-in-unit-tests-multi-column-literal.sql | 39 +++++ .../not-in-unit-tests-multi-column.sql | 98 ++++++++++++ ...ot-in-unit-tests-single-column-literal.sql | 42 +++++ .../not-in-unit-tests-single-column.sql | 123 +++++++++++++++ ...in-unit-tests-multi-column-literal.sql.out | 54 +++++++ .../not-in-unit-tests-multi-column.sql.out | 134 ++++++++++++++++ ...n-unit-tests-single-column-literal.sql.out | 69 ++++++++ .../not-in-unit-tests-single-column.sql.out | 149 ++++++++++++++++++ .../typeCoercion/native/concat.sql.out | 2 +- 10 files changed, 714 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7d..de89e17e51f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -116,15 +116,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // (a1,a2,...) = (b1,b2,...) // to // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... - val joinConds = splitConjunctivePredicates(joinCond.get) + val baseJoinConds = splitConjunctivePredicates(joinCond.get) + val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c))) // After that, add back the correlated join predicate(s) in the subquery // Example: // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) // will have the final conditions in the LEFT ANTI as - // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) - val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs))) + dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql new file mode 100644 index 0000000000000..8eea84f4f5272 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -0,0 +1,39 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- This file has the same test cases as not-in-unit-tests-multi-column.sql with literals instead of +-- subqueries. Small changes have been made to the literals to make them typecheck. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +-- Case 1 (not possible to write a literal with no rows, so we ignore it.) +-- (subquery is empty -> row is returned) + +-- Cases 2, 3 and 4 are currently broken, so I have commented them out here. +-- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases. + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql new file mode 100644 index 0000000000000..9f8dc7fca3b94 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql @@ -0,0 +1,98 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- +-- Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? | +-- | 1 | empty | * | * | * | yes | +-- | 2 | 1+ row has null for all columns | * | * | * | no | +-- | 3 | no row has null for all columns | (yes, yes) | * | * | no | +-- | 4 | no row has null for all columns | (no, yes) | yes | * | no | +-- | 5 | no row has null for all columns | (no, yes) | no | * | yes | +-- | 6 | no | (no, no) | yes | yes | no | +-- | 7 | no | (no, no) | _ | _ | yes | +-- +-- This can be generalized to include more tests for more columns, but it covers the main cases +-- when there is more than one column. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d); + + -- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +; + + -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +; + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql new file mode 100644 index 0000000000000..b261363d1dde7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql @@ -0,0 +1,42 @@ +-- Unit tests for simple NOT IN with a literal expression of a single column +-- +-- More information can be found in not-in-unit-tests-single-column.sql. +-- This file has the same test cases as not-in-unit-tests-single-column.sql with literals instead of +-- subqueries. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null); + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2); + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2); + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql new file mode 100644 index 0000000000000..2cc08e10acf67 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql @@ -0,0 +1,123 @@ +-- Unit tests for simple NOT IN predicate subquery across a single column. +-- +-- ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the +-- rules are confusing to the uninitiated, and precedence and treatment of null values is plain +-- unintuitive. To make this simpler to understand, I've come up with a plain English way of +-- describing the expected behavior of this query. +-- +-- - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of +-- whether the filtered columns include nulls. +-- - If the subquery contains a result with all columns null, then the row should not be returned. +-- - If for all non-null filter columns there exists a row in the subquery in which each column +-- either +-- 1. is equal to the corresponding filter column or +-- 2. is null +-- then the row should not be returned. (This includes the case where all filter columns are +-- null.) +-- - Otherwise, the row should be returned. +-- +-- Using these rules, we can come up with a set of test cases for single-column and multi-column +-- NOT IN test cases. +-- +-- Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | is a null? | a = c? | row with a included in result? | +-- | 1 | empty | | | yes | +-- | 2 | yes | | | no | +-- | 3 | no | yes | | no | +-- | 4 | no | no | yes | no | +-- | 5 | no | no | no | yes | +-- +-- There are also some considerations around correlated subqueries. Correlated subqueries can +-- cause cases 2, 3, or 4 to be reduced to case 1 by limiting the number of rows returned by the +-- subquery, so the row from the parent table should always be included in the output. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +; + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +; + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +; + + -- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out new file mode 100644 index 0000000000000..a16e98af9a417 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 1 schema +struct +-- !query 1 output +NULL 1 + + +-- !query 2 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 3 schema +struct +-- !query 3 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out new file mode 100644 index 0000000000000..aa5f64b8ebf55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out @@ -0,0 +1,134 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 +NULL NULL + + +-- !query 3 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 6 schema +struct +-- !query 6 output +NULL 1 + + +-- !query 7 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 7 schema +struct +-- !query 7 output + + + +-- !query 8 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 8 schema +struct +-- !query 8 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out new file mode 100644 index 0000000000000..446447e890449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null) +-- !query 1 schema +struct +-- !query 1 output + + + +-- !query 2 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6) +-- !query 4 schema +struct +-- !query 4 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out new file mode 100644 index 0000000000000..f58ebeacc2872 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out @@ -0,0 +1,149 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 + + +-- !query 3 +-- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +-- !query 6 schema +struct +-- !query 6 output +2 3 + + +-- !query 7 +-- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 7 schema +struct +-- !query 7 output +2 3 +4 5 +NULL 1 + + +-- !query 8 +-- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 8 schema +struct +-- !query 8 output +NULL 1 + + +-- !query 9 +-- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 9 schema +struct +-- !query 9 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 62befc5ca0f15..be637b66abc86 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 14 -- !query 0 From 672209f2909a95e891f3c779bfb2f0e534239851 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 28 May 2018 10:50:17 +0800 Subject: [PATCH 570/593] [SPARK-24334] Fix race condition in ArrowPythonRunner causes unclean shutdown of Arrow memory allocator ## What changes were proposed in this pull request? There is a race condition of closing Arrow VectorSchemaRoot and Allocator in the writer thread of ArrowPythonRunner. The race results in memory leak exception when closing the allocator. This patch removes the closing routine from the TaskCompletionListener and make the writer thread responsible for cleaning up the Arrow memory. This issue be reproduced by this test: ``` def test_memory_leak(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType, array, lit, explode # Have all data in a single executor thread so it can trigger the race condition easier with self.sql_conf({'spark.sql.shuffle.partitions': 1}): df = self.spark.range(0, 1000) df = df.withColumn('id', array([lit(i) for i in range(0, 300)])) \ .withColumn('id', explode(col('id'))) \ .withColumn('v', array([lit(i) for i in range(0, 1000)])) pandas_udf(df.schema, PandasUDFType.GROUPED_MAP) def foo(pdf): xxx return pdf result = df.groupby('id').apply(foo) with QuietTest(self.sc): with self.assertRaises(py4j.protocol.Py4JJavaError) as context: result.count() self.assertTrue('Memory leaked' not in str(context.exception)) ``` Note: Because of the race condition, the test case cannot reproduce the issue reliably so it's not added to test cases. ## How was this patch tested? Because of the race condition, the bug cannot be unit test easily. So far it has only happens on large amount of data. This is currently tested manually. Author: Li Jin Closes #21397 from icexelloss/SPARK-24334-arrow-memory-leak. --- .../execution/python/ArrowPythonRunner.scala | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5fcdcddca7d51..01e19bddbfb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -70,19 +70,13 @@ class ArrowPythonRunner( val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - context.addTaskCompletionListener { _ => - root.close() - allocator.close() - } - - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + while (inputIterator.hasNext) { val nextBatch = inputIterator.next() @@ -94,8 +88,21 @@ class ArrowPythonRunner( writer.writeBatch() arrowWriter.reset() } - } { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. root.close() allocator.close() } From de01a8d50c9c3e196591db057d544f5d7b24d95f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 28 May 2018 12:09:44 +0800 Subject: [PATCH 571/593] [SPARK-24373][SQL] Add AnalysisBarrier to RelationalGroupedDataset's and KeyValueGroupedDataset's child ## What changes were proposed in this pull request? When we create a `RelationalGroupedDataset` or a `KeyValueGroupedDataset` we set its child to the `logicalPlan` of the `DataFrame` we need to aggregate. Since the `logicalPlan` is already analyzed, we should not analyze it again. But this happens when the new plan of the aggregate is analyzed. The current behavior in most of the cases is likely to produce no harm, but in other cases re-analyzing an analyzed plan can change it, since the analysis is not idempotent. This can cause issues like the one described in the JIRA (missing to find a cached plan). The PR adds an `AnalysisBarrier` to the `logicalPlan` which is used as child of `RelationalGroupedDataset` or a `KeyValueGroupedDataset`. ## How was this patch tested? added UT Author: Marco Gaido Closes #21432 from mgaido91/SPARK-24373. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 12 +-- .../spark/sql/GroupedDatasetSuite.scala | 96 +++++++++++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 32267eb0300f5..abb5ae53f4d73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,7 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cbd..36f6038aa9485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = queryExecution.analyzed + private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6c2be3610ae30..c6449cd5a16b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -63,17 +63,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier)) } } @@ -433,7 +433,7 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.schema, groupingAttributes, df.logicalPlan.output, - df.logicalPlan)) + df.planWithBarrier)) } /** @@ -459,7 +459,7 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.logicalPlan + val child = df.planWithBarrier val project = Project(groupingNamedExpressions ++ child.output, child) val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala new file mode 100644 index 0000000000000..147c0b61f5017 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} + +class GroupedDatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val scalaUDF = udf((x: Long) => { x + 1 }) + private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) + + private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { + assert(atLevel >= 0) + var children = Seq(ds.queryExecution.logical) + (1 to atLevel).foreach { _ => + children = children.flatMap(_.children) + } + val barriers = children.collect { + case ab: AnalysisBarrier => ab + } + assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + + ds.queryExecution.logical) + } + + test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { + val groupByDataset = datasetWithUDF.groupBy() + val rollupDataset = datasetWithUDF.rollup("s") + val cubeDataset = datasetWithUDF.cube("s") + val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) + datasetWithUDF.cache() + Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => + val df = rgDS.count() + assertContainsAnalysisBarrier(df) + assertCached(df) + } + + val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( + Array.emptyByteArray, + Array.emptyByteArray, + Array.empty, + StructType(Seq(StructField("s", LongType)))) + val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => + assertContainsAnalysisBarrier(df, 2) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } + + test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { + val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) + datasetWithUDF.cache() + val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) + val keysKVDataset = kvDasaset.keys + val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) + val aggKVDataset = kvDasaset.count() + val otherKVDataset = spark.range(1).groupByKey(_ + 1) + val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) + Seq((mapValuesKVDataset, 1), + (keysKVDataset, 2), + (flatMapGroupsKVDataset, 2), + (aggKVDataset, 1), + (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => + assertContainsAnalysisBarrier(df, analysisBarrierDepth) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } +} From fa2ae9d2019f839647d17932d8fea769e7622777 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 28 May 2018 12:56:05 +0800 Subject: [PATCH 572/593] [SPARK-24392][PYTHON] Label pandas_udf as Experimental ## What changes were proposed in this pull request? The pandas_udf functionality was introduced in 2.3.0, but is not completely stable and still evolving. This adds a label to indicate it is still an experimental API. ## How was this patch tested? NA Author: Bryan Cutler Closes #21435 from BryanCutler/arrow-pandas_udf-experimental-SPARK-24392. --- docs/sql-programming-guide.md | 4 ++++ python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/functions.py | 2 ++ python/pyspark/sql/group.py | 2 ++ python/pyspark/sql/session.py | 2 ++ 5 files changed, 12 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fc26562ff33da..50600861912b1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1827,6 +1827,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. +## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above + + - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 213dc158f9328..808235ab25440 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1975,6 +1975,8 @@ def toPandas(self): .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fbc8a2d038f8f..efcce25a08e04 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2456,6 +2456,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. + .. note:: Experimental + The function type of the UDF can be one of the following: 1. SCALAR diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 3505065b648f2..0906c9c6b329a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -236,6 +236,8 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. + .. note:: Experimental + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 13d6e2e53dbd0..d675a240172a7 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -584,6 +584,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr .. versionchanged:: 2.1 Added verifySchema. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] From b31b587cd091010337378cf448fd598c37757053 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 29 May 2018 10:35:30 +0800 Subject: [PATCH 573/593] [SPARK-19613][SS][TEST] Random.nextString is not safe for directory namePrefix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `Random.nextString` is good for generating random string data, but it's not proper for directory name prefix in `Utils.createDirectory(tempDir, Random.nextString(10))`. This PR uses more safe directory namePrefix. ```scala scala> scala.util.Random.nextString(10) res0: String = 馨쭔ᎰႻ穚䃈兩㻞藑並 ``` ```scala StateStoreRDDSuite: - versioning and immutability - recovering from files - usage with iterators - only gets and only puts - preferred locations using StateStoreCoordinator *** FAILED *** java.io.IOException: Failed to create a temp directory (under /.../spark/sql/core/target/tmp/StateStoreRDDSuite8712796397908632676) after 10 attempts! at org.apache.spark.util.Utils$.createDirectory(Utils.scala:295) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13$$anonfun$apply$6.apply(StateStoreRDDSuite.scala:152) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13$$anonfun$apply$6.apply(StateStoreRDDSuite.scala:149) at org.apache.spark.sql.catalyst.util.package$.quietly(package.scala:42) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13.apply(StateStoreRDDSuite.scala:149) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13.apply(StateStoreRDDSuite.scala:149) ... - distributed test *** FAILED *** java.io.IOException: Failed to create a temp directory (under /.../spark/sql/core/target/tmp/StateStoreRDDSuite8712796397908632676) after 10 attempts! at org.apache.spark.util.Utils$.createDirectory(Utils.scala:295) ``` ## How was this patch tested? Pass the existing tests.StateStoreRDDSuite: Author: Dongjoon Hyun Closes #21446 from dongjoon-hyun/SPARK-19613. --- .../execution/streaming/state/StateStoreRDDSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 65b39f0fbd73d..579a364ebc3e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -55,7 +55,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) @@ -73,7 +73,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString def makeStoreRDD( spark: SparkSession, @@ -101,7 +101,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("usage with iterators - only gets and only puts") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 // Returns an iterator of the incremented value made into the store @@ -149,7 +149,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn quietly { val queryRunId = UUID.randomUUID val opId = 0 - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext @@ -189,7 +189,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) .getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) From 2ced6193b39dc63e5f74138859f2a9d69d3cfd11 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 29 May 2018 10:48:48 +0800 Subject: [PATCH 574/593] [SPARK-24377][SPARK SUBMIT] make --py-files work in non pyspark application ## What changes were proposed in this pull request? For some Spark applications, though they're a java program, they require not only jar dependencies, but also python dependencies. One example is Livy remote SparkContext application, this application is actually an embedded REPL for Scala/Python/R, it will not only load in jar dependencies, but also python and R deps, so we should specify not only "--jars", but also "--py-files". Currently for a Spark application, --py-files can only be worked for a pyspark application, so it will not be worked in the above case. So here propose to remove such restriction. Also we tested that "spark.submit.pyFiles" only supports quite limited scenario (client mode with local deps), so here also expand the usage of "spark.submit.pyFiles" to be alternative of --py-files. ## How was this patch tested? UT added. Author: jerryshao Closes #21420 from jerryshao/SPARK-24377. --- .../org/apache/spark/deploy/SparkSubmit.scala | 11 ++--- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../spark/deploy/SparkSubmitSuite.scala | 46 ++++++++++++++++++- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4baf032f0e9c6..a46af26feb061 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -430,18 +430,15 @@ private[spark] class SparkSubmit extends Logging { // Usage: PythonAppRunner
    [app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs - if (clusterManager != YARN) { - // The YARN backend distributes the primary file differently, so don't merge it. - args.files = mergeFileLists(args.files, args.primaryResource) - } } if (clusterManager != YARN) { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (localPyFiles != null) { - sparkConf.set("spark.submit.pyFiles", localPyFiles) - } + } + + if (localPyFiles != null) { + sparkConf.set("spark.submit.pyFiles", localPyFiles) } // In YARN mode for an R app, add the SparkR package archive and the R package diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fed4e0a5069c3..fb232101114b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -182,6 +182,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull + pyFiles = Option(pyFiles).orElse(sparkProperties.get("spark.submit.pyFiles")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull @@ -280,9 +281,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } - if (pyFiles != null && !isPython) { - error("--py-files given but primary resource is not a Python script") - } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 43286953e4383..545c8d0423dc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -771,9 +771,13 @@ class SparkSubmitSuite PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val pyFile1 = File.createTempFile("file1", ".py", tmpDir) + val pyFile2 = File.createTempFile("file2", ".py", tmpDir) val writer4 = new PrintWriter(f4) - val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" writer4.println("spark.submit.pyFiles " + remotePyFiles) writer4.close() val clArgs4 = Seq( @@ -783,7 +787,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -1093,6 +1097,44 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + test("support --py-files/spark.submit.pyFiles in non pyspark application") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + + conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.submit.pyFiles") should (startWith("/")) + + // Verify "spark.submit.pyFiles" + val args1 = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs1 = new SparkSubmitArguments(args1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) + + conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf1.get("spark.submit.pyFiles") should (startWith("/")) + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { From 23db600c956cff4d0b20c38ddd2d746de2b535a0 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 28 May 2018 23:23:22 -0700 Subject: [PATCH 575/593] [SPARK-24250][SQL][FOLLOW-UP] support accessing SQLConf inside tasks ## What changes were proposed in this pull request? We should not stop users from calling `getActiveSession` and `getDefaultSession` in executors. To not break the existing behaviors, we should simply return None. ## How was this patch tested? N/A Author: Xiao Li Closes #21436 from gatorsmile/followUpSPARK-24250. --- .../org/apache/spark/sql/SparkSession.scala | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..565042fcf762e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1021,21 +1021,33 @@ object SparkSession extends Logging { /** * Returns the active SparkSession for the current thread, returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(activeThreadSession.get) + } } /** * Returns the default SparkSession that is returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(defaultSession.get) + } } /** From aca65c63cb12073eb193fe08998994c60acb8b58 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 29 May 2018 20:10:59 +0800 Subject: [PATCH 576/593] [SPARK-23991][DSTREAMS] Fix data loss when WAL write fails in allocateBlocksToBatch When blocks tried to get allocated to a batch and WAL write fails then the blocks will be removed from the received block queue. This fact simply produces data loss because the next allocation will not find the mentioned blocks in the queue. In this PR blocks will be removed from the received queue only if WAL write succeded. Additional unit test. Author: Gabor Somogyi Closes #21430 from gaborgsomogyi/SPARK-23991. Change-Id: I5ead84f0233f0c95e6d9f2854ac2ff6906f6b341 --- .../scheduler/ReceivedBlockTracker.scala | 3 +- .../streaming/ReceivedBlockTrackerSuite.scala | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index dacff69d55dd2..cf4324578ea87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -112,10 +112,11 @@ private[streaming] class ReceivedBlockTracker( def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { val streamIdToBlocks = streamIds.map { streamId => - (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + (streamId, getReceivedBlockQueue(streamId).clone()) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + streamIds.foreach(getReceivedBlockQueue(_).clear()) timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 4fa236bd39663..fd7e00b1de25f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -26,10 +26,12 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration +import org.mockito.Matchers.any +import org.mockito.Mockito.{doThrow, reset, spy} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -115,6 +117,47 @@ class ReceivedBlockTrackerSuite tracker2.stop() } + test("block allocation to batch should not loose blocks from received queue") { + val tracker1 = spy(createTracker()) + tracker1.isWriteAheadLogEnabled should be (true) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + // Add blocks + val blockInfos = generateBlockInfos() + blockInfos.map(tracker1.addBlock) + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Try to allocate the blocks to a batch and verify that it's failing + // The blocks should stay in the received queue when WAL write failing + doThrow(new RuntimeException("Not able to write BatchAllocationEvent")) + .when(tracker1).writeToLog(any(classOf[BatchAllocationEvent])) + val errMsg = intercept[RuntimeException] { + tracker1.allocateBlocksToBatch(1) + } + assert(errMsg.getMessage === "Not able to write BatchAllocationEvent") + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + tracker1.getBlocksOfBatch(1) shouldEqual Map.empty + tracker1.getBlocksOfBatchAndStream(1, streamId) shouldEqual Seq.empty + + // Allocate the blocks to a batch and verify that all of them have been allocated + reset(tracker1) + tracker1.allocateBlocksToBatch(2) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker1.hasUnallocatedReceivedBlocks should be (false) + tracker1.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker1.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + + tracker1.stop() + + // Recover from WAL to see the correctness + val tracker2 = createTracker(recoverFromWriteAheadLog = true) + tracker2.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker2.hasUnallocatedReceivedBlocks should be (false) + tracker2.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker2.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates @@ -312,7 +355,7 @@ class ReceivedBlockTrackerSuite recoverFromWriteAheadLog: Boolean = false, clock: Clock = new SystemClock): ReceivedBlockTracker = { val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None - val tracker = new ReceivedBlockTracker( + var tracker = new ReceivedBlockTracker( conf, hadoopConf, Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption) allReceivedBlockTrackers += tracker tracker From 900bc1f7dc5a7c013b473fceab1c4052ade74a2f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 29 May 2018 10:22:18 -0700 Subject: [PATCH 577/593] [SPARK-24371][SQL] Added isInCollection in DataFrame API for Scala and Java. ## What changes were proposed in this pull request? Implemented **`isInCollection `** in DataFrame API for both Scala and Java, so users can do ```scala val profileDF = Seq( Some(1), Some(2), Some(3), Some(4), Some(5), Some(6), Some(7), None ).toDF("profileID") val validUsers: Seq[Any] = Seq(6, 7.toShort, 8L, "3") val result = profileDF.withColumn("isValid", $"profileID". isInCollection(validUsers)) result.show(10) """ +---------+-------+ |profileID|isValid| +---------+-------+ | 1| false| | 2| false| | 3| true| | 4| false| | 5| false| | 6| true| | 7| true| | null| null| +---------+-------+ """.stripMargin ``` ## How was this patch tested? Several unit tests are added. Author: DB Tsai Closes #21416 from dbtsai/optimize-set. --- .../sql/catalyst/optimizer/expressions.scala | 1 - .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 19 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 62 ++++++++++++++++++- 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e487693927ab6..c486ad700f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..b3e59f53ee3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging { @scala.annotation.varargs def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group expr_ops + * @since 2.4.0 + */ + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Scala Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Java Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { From f48938800e6dc3880441f160dd93856b9f86874e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 30 May 2018 09:32:33 +0800 Subject: [PATCH 578/593] [SPARK-24365][SQL] Add Data Source write benchmark ## What changes were proposed in this pull request? Add Data Source write benchmark. So that it would be easier to measure the writer performance. Author: Gengliang Wang Closes #21409 from gengliangwang/parquetWriteBenchmark. --- .../benchmark/DataSourceWriteBenchmark.scala | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2d2cdebd067c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure data source write performance. + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV: + * spark-submit --class + * To measure specified formats, run it with arguments: + * spark-submit --class format1 [format2] [...] + */ +object DataSourceWriteBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceWriteBenchmark") + .setIfMissing("spark.master", "local[1]") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.orc.compression.codec", "snappy") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + val tempTable = "temp" + val numRows = 1024 * 1024 * 15 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = { + spark.sql(s"create table $table(id $dataType) using $format") + benchmark.addCase(s"Output Single $dataType Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable") + } + } + + def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format") + benchmark.addCase("Output Int and String Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS STRING) AS c2 FROM $tempTable") + } + } + + def writePartition(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)") + benchmark.addCase("Output Partitions") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," + + s" CAST(id % 2 AS INT) AS p FROM $tempTable") + } + } + + def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS") + benchmark.addCase("Output Buckets") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS INT) AS c2 FROM $tempTable") + } + } + + def main(args: Array[String]): Unit = { + val tableInt = "tableInt" + val tableDouble = "tableDouble" + val tableIntString = "tableIntString" + val tablePartition = "tablePartition" + val tableBucket = "tableBucket" + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + withTempTable(tempTable) { + spark.range(numRows).createOrReplaceTempView(tempTable) + formats.foreach { format => + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() + } + } + } + } +} From a4be981c0476bb613c660b70a370f671c8b3ffee Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 29 May 2018 23:26:39 -0700 Subject: [PATCH 579/593] [SPARK-24331][SPARKR][SQL] Adding arrays_overlap, array_repeat, map_entries to SparkR ## What changes were proposed in this pull request? The PR adds functions `arrays_overlap`, `array_repeat`, `map_entries` to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Examples ### arrays_overlap ``` df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), list(list(1L, 2L), list(3L, 4L)), list(list(1L, NA), list(3L, 4L)))) collect(select(df, arrays_overlap(df[[1]], df[[2]]))) ``` ``` arrays_overlap(_1, _2) 1 TRUE 2 FALSE 3 NA ``` ### array_repeat ``` df <- createDataFrame(list(list("a", 3L), list("b", 2L))) collect(select(df, array_repeat(df[[1]], df[[2]]))) ``` ``` array_repeat(_1, _2) 1 a, a, a 2 b, b ``` ``` collect(select(df, array_repeat(df[[1]], 2L))) ``` ``` array_repeat(_1, 2) 1 a, a 2 b, b ``` ### map_entries ``` df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) collect(select(df, map_entries(df$map))) ``` ``` map_entries(map) 1 x, 1, y, 2 ``` Author: Marek Novotny Closes #21434 from mn-mikke/SPARK-24331. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/DataFrame.R | 2 + R/pkg/R/functions.R | 58 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 12 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 22 +++++++++- 5 files changed, 91 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c575fe255f57a..73a33af4dd48b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,7 +204,9 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_repeat", "array_sort", + "arrays_overlap", "asc", "ascii", "asin", @@ -302,6 +304,7 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", "map_keys", "map_values", "max", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..70eb7a874b75c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index fcb3521f901ea..abc91aeeb4825 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,7 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. #' @param value A value to compute on. #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. @@ -207,7 +208,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -216,11 +217,10 @@ NULL #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) -#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) -#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL @@ -3048,6 +3048,26 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + #' @details #' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array #' must be orderable. NA elements will be placed at the end of the returned array. @@ -3062,6 +3082,21 @@ setMethod("array_sort", column(jc) }) +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3076,6 +3111,19 @@ setMethod("flatten", column(jc) }) +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3ea181157b644..8894cb1c5b92f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,10 +769,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1034,6 +1042,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 13b55ac6e6e3c..16c1fd5a065eb 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,21 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1531,8 +1546,13 @@ test_that("column functions", { result <- collect(select(df, flatten(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) - # Test map_keys(), map_values() and element_at() + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) From 0ebb0c0d4dd3e192464dc5e0e6f01efa55b945ed Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Wed, 30 May 2018 18:11:33 +0800 Subject: [PATCH 580/593] [SPARK-23754][PYTHON] Re-raising StopIteration in client code ## What changes were proposed in this pull request? Make sure that `StopIteration`s raised in users' code do not silently interrupt processing by spark, but are raised as exceptions to the users. The users' functions are wrapped in `safe_iter` (in `shuffle.py`), which re-raises `StopIteration`s as `RuntimeError`s ## How was this patch tested? Unit tests, making sure that the exceptions are indeed raised. I am not sure how to check whether a `Py4JJavaError` contains my exception, so I simply looked for the exception message in the java exception's `toString`. Can you propose a better way? ## License This is my original work, licensed in the same way as spark Author: e-dorigatti Author: edorigatti Closes #21383 from e-dorigatti/fix_spark_23754. --- python/pyspark/rdd.py | 18 ++++++++++--- python/pyspark/shuffle.py | 7 ++--- python/pyspark/sql/tests.py | 16 +++++++++++ python/pyspark/sql/udf.py | 14 ++++++++-- python/pyspark/tests.py | 53 +++++++++++++++++++++++++++++++++++++ python/pyspark/util.py | 28 +++++++++++++++++--- 6 files changed, 125 insertions(+), 11 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d5a237a5b2855..14d9128502ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -417,7 +418,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -798,6 +799,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -847,6 +850,8 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -918,6 +923,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -950,6 +957,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1643,6 +1653,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c7bd8f01b907f..a2450932e303d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,22 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + func = fail_on_stopiteration(self.func) + + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + func._argspec = _get_argspec(self.func) + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..3b37cc028c1b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1246,6 +1277,28 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,11 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - if sys.version_info[0] < 3: + + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here + argspec = f._argspec + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec @@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9e7bad0edd9f6c59c0af21c95e5df98cf82150d3 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 30 May 2018 05:18:18 -0700 Subject: [PATCH 581/593] [SPARK-24419][BUILD] Upgrade SBT to 0.13.17 with Scala 2.10.7 for JDK9+ ## What changes were proposed in this pull request? Upgrade SBT to 0.13.17 with Scala 2.10.7 for JDK9+ ## How was this patch tested? Existing tests Author: DB Tsai Closes #21458 from dbtsai/sbt. --- project/build.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/build.properties b/project/build.properties index b19518fd7aa1c..d03985d980ec8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.16 +sbt.version=0.13.17 From 1e46f92f956a00d04d47340489b6125d44dbd47b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 31 May 2018 00:23:25 +0800 Subject: [PATCH 582/593] [SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set ## What changes were proposed in this pull request? This pr fixed an issue when having multiple distinct aggregations having the same argument set, e.g., ``` scala>: paste val df = sql( s"""SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) | FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) """.stripMargin) java.lang.RuntimeException You hit a query analyzer bug. Please report your query to Spark user mailing list. ``` The root cause is that `RewriteDistinctAggregates` can't detect multiple distinct aggregations if they have the same argument set. This pr modified code so that `RewriteDistinctAggregates` could count the number of aggregate expressions with `isDistinct=true`. ## How was this patch tested? Added tests in `DataFrameAggregateSuite`. Author: Takeshi Yamamuro Closes #21443 from maropu/SPARK-24369. --- .../optimizer/RewriteDistinctAggregates.scala | 7 ++++--- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/resources/sql-tests/inputs/group-by.sql | 6 +++++- .../test/resources/sql-tests/results/group-by.sql.out | 11 ++++++++++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 4448ace7105a4..bc898ab0dc723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,7 +115,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => + val distincgAggExpressions = aggExpressions.filter(_.isDistinct) + val distinctAggGroups = distincgAggExpressions.groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -132,7 +133,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size > 1) { + if (distincgAggExpressions.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -151,7 +152,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..b9452b58657a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.partition(_.isDistinct) if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 From b142157dcc7f595eea93d66dda8b1d169a38d95c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 30 May 2018 10:33:34 -0700 Subject: [PATCH 583/593] [SPARK-24384][PYTHON][SPARK SUBMIT] Add .py files correctly into PythonRunner in submit with client mode in spark-submit ## What changes were proposed in this pull request? In client side before context initialization specifically, .py file doesn't work in client side before context initialization when the application is a Python file. See below: ``` $ cat /home/spark/tmp.py def testtest(): return 1 ``` This works: ``` $ cat app.py import pyspark pyspark.sql.SparkSession.builder.getOrCreate() import tmp print("************************%s" % tmp.testtest()) $ ./bin/spark-submit --master yarn --deploy-mode client --py-files /home/spark/tmp.py app.py ... ************************1 ``` but this doesn't: ``` $ cat app.py import pyspark import tmp pyspark.sql.SparkSession.builder.getOrCreate() print("************************%s" % tmp.testtest()) $ ./bin/spark-submit --master yarn --deploy-mode client --py-files /home/spark/tmp.py app.py Traceback (most recent call last): File "/home/spark/spark/app.py", line 2, in import tmp ImportError: No module named tmp ``` ### How did it happen? In client mode specifically, the paths are being added into PythonRunner as are: https://github.com/apache/spark/blob/628c7b517969c4a7ccb26ea67ab3dd61266073ca/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L430 https://github.com/apache/spark/blob/628c7b517969c4a7ccb26ea67ab3dd61266073ca/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala#L49-L88 The problem here is, .py file shouldn't be added as are since `PYTHONPATH` expects a directory or an archive like zip or egg. ### How does this PR fix? We shouldn't simply just add its parent directory because other files in the parent directory could also be added into the `PYTHONPATH` in client mode before context initialization. Therefore, we copy .py files into a temp directory for .py files and add it to `PYTHONPATH`. ## How was this patch tested? Unit tests are added and manually tested in both standalond and yarn client modes with submit. Author: hyukjinkwon Closes #21426 from HyukjinKwon/SPARK-24384. --- .../apache/spark/deploy/PythonRunner.scala | 29 ++++++++++++++++++- .../spark/deploy/yarn/YarnClusterSuite.scala | 15 ++++------ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 1b7e031ee0678..ccb30e205ca40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.io.File import java.net.{InetAddress, URI} +import java.nio.file.Files import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -48,7 +49,7 @@ object PythonRunner { // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) - val formattedPyFiles = formatPaths(pyFiles) + val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such @@ -153,4 +154,30 @@ object PythonRunner { .map { p => formatPath(p, testWindows) } } + /** + * Resolves the ".py" files. ".py" file should not be added as is because PYTHONPATH does + * not expect a file. This method creates a temporary directory and puts the ".py" files + * if exist in the given paths. + */ + private def resolvePyFiles(pyFiles: Array[String]): Array[String] = { + lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles") + pyFiles.flatMap { pyFile => + // In case of client with submit, the python paths should be set before context + // initialization because the context initialization can be done later. + // We will copy the local ".py" files because ".py" file shouldn't be added + // alone but its parent directory in PYTHONPATH. See SPARK-24384. + if (pyFile.endsWith(".py")) { + val source = new File(pyFile) + if (source.exists() && source.isFile && source.canRead) { + Files.copy(source.toPath, new File(dest, source.getName).toPath) + Some(dest.getAbsolutePath) + } else { + // Don't have to add it if it doesn't exist or isn't readable. + None + } + } else { + Some(pyFile) + } + }.distinct + } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59b0f29e37d84..3b78b88de778d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -271,16 +271,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv - val moduleDir = - if (clientMode) { - // In client-mode, .py files added with --py-files are not visible in the driver. - // This is something that the launcher library would have to handle. - tempDir - } else { - val subdir = new File(tempDir, "pyModules") - subdir.mkdir() - subdir - } + val moduleDir = { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } val pyModule = new File(moduleDir, "mod1.py") Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) From ec6f971dc57bcdc0ad65ac1987b6f0c1801157f4 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 30 May 2018 11:04:09 -0700 Subject: [PATCH 584/593] [SPARK-23161][PYSPARK][ML] Add missing APIs to Python GBTClassifier ## What changes were proposed in this pull request? Add featureSubsetStrategy in GBTClassifier and GBTRegressor. Also make GBTClassificationModel inherit from JavaClassificationModel instead of prediction model so it will have numClasses. ## How was this patch tested? Add tests in doctest Author: Huaxin Gao Closes #21413 from huaxingao/spark-23161. --- python/pyspark/ml/classification.py | 35 ++++++++++++--- python/pyspark/ml/regression.py | 70 ++++++++++++++++++----------- 2 files changed, 74 insertions(+), 31 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 424ecfd89b060..1754c48937a62 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1131,6 +1131,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): @@ -1193,6 +1200,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) + >>> gbt.getFeatureSubsetStrategy() + 'all' >>> model = gbt.fit(td) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1226,6 +1235,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol ... ["indexed", "features"]) >>> model.evaluateEachIteration(validation) [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] + >>> model.numClasses + 2 .. versionadded:: 1.4.0 """ @@ -1244,19 +1255,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, + featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1265,12 +1279,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1293,8 +1309,15 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + -class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, +class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dd0b62f184d26..dba0e57b01a0b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -602,6 +602,14 @@ class TreeEnsembleParams(DecisionTreeParams): "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat) + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", + typeConverter=TypeConverters.toString) + def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -619,6 +627,22 @@ def getSubsamplingRate(self): """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + + .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0. + """ + return self._set(featureSubsetStrategy=value) + + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + class TreeRegressorParams(Params): """ @@ -654,14 +678,8 @@ class RandomForestParams(TreeEnsembleParams): Private class to track supported random forest parameters. """ - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt) - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -680,20 +698,6 @@ def getNumTrees(self): """ return self.getOrDefault(self.numTrees) - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - return self._set(featureSubsetStrategy=value) - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class GBTParams(TreeEnsembleParams): """ @@ -981,6 +985,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): @@ -1029,6 +1040,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> print(gbt.getImpurity()) variance + >>> print(gbt.getFeatureSubsetStrategy()) + all >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1079,20 +1092,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance"): + impurity="variance", featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance") + impurity="variance", featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1102,13 +1115,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance"): + impuriy="variance", featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1131,6 +1144,13 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ From 1b36f148891ac41ef36a40366f87dd5405cb3751 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 30 May 2018 11:18:04 -0700 Subject: [PATCH 585/593] [SPARK-23901][SQL] Add masking functions ## What changes were proposed in this pull request? The PR adds the masking function as they are described in Hive's documentation: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-DataMaskingFunctions. This means that only `string`s are accepted as parameter for the masking functions. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21246 from mgaido91/SPARK-23901. --- .../expressions/MaskExpressionsUtils.java | 80 +++ .../catalyst/analysis/FunctionRegistry.scala | 8 + .../expressions/maskExpressions.scala | 569 ++++++++++++++++++ .../expressions/MaskExpressionsSuite.scala | 236 ++++++++ .../org/apache/spark/sql/functions.scala | 119 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 107 ++++ 6 files changed, 1119 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java new file mode 100644 index 0000000000000..05879902a4ed9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Contains all the Utils methods used in the masking expressions. + */ +public class MaskExpressionsUtils { + static final int UNMASKED_VAL = -1; + + /** + * Returns the masking character for {@param c} or {@param c} is it should not be masked. + * @param c the character to transform + * @param maskedUpperChar the character to use instead of a uppercase letter + * @param maskedLowerChar the character to use instead of a lowercase letter + * @param maskedDigitChar the character to use instead of a digit + * @param maskedOtherChar the character to use instead of a any other character + * @return masking character for {@param c} + */ + public static int transformChar( + final int c, + int maskedUpperChar, + int maskedLowerChar, + int maskedDigitChar, + int maskedOtherChar) { + switch(Character.getType(c)) { + case Character.UPPERCASE_LETTER: + if(maskedUpperChar != UNMASKED_VAL) { + return maskedUpperChar; + } + break; + + case Character.LOWERCASE_LETTER: + if(maskedLowerChar != UNMASKED_VAL) { + return maskedLowerChar; + } + break; + + case Character.DECIMAL_DIGIT_NUMBER: + if(maskedDigitChar != UNMASKED_VAL) { + return maskedDigitChar; + } + break; + + default: + if(maskedOtherChar != UNMASKED_VAL) { + return maskedOtherChar; + } + break; + } + + return c; + } + + /** + * Returns the replacement char to use according to the {@param rep} specified by the user and + * the {@param def} default. + */ + public static int getReplacementChar(String rep, int def) { + if (rep != null && rep.length() > 0) { + return rep.codePointAt(0); + } + return def; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1134a8866dc13..23a4a440fac23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -432,6 +432,14 @@ object FunctionRegistry { expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, + // mask functions + expression[Mask]("mask"), + expression[MaskFirstN]("mask_first_n"), + expression[MaskLastN]("mask_last_n"), + expression[MaskShowFirstN]("mask_show_first_n"), + expression[MaskShowLastN]("mask_show_last_n"), + expression[MaskHash]("mask_hash"), + // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala new file mode 100644 index 0000000000000..276a57266a6e0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ +import org.apache.spark.sql.catalyst.expressions.MaskLike._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait MaskLike { + def upper: String + def lower: String + def digit: String + + protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) + protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) + protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) + + protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName + + def inputStringLengthCode(inputString: String, length: String): String = { + s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + } + + def appendMaskedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, + | $upperReplacement, $lowerReplacement, + | $digitReplacement, $defaultMaskedOther)); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendUnchangedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($codePoint); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendMaskedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(transformChar( + codePoint, + upperReplacement, + lowerReplacement, + digitReplacement, + defaultMaskedOther)) + offset += Character.charCount(codePoint) + } + offset + } + + def appendUnchangedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(codePoint) + offset += Character.charCount(codePoint) + } + offset + } +} + +trait MaskLikeWithN extends MaskLike { + def n: Int + protected lazy val charCount: Int = if (n < 0) 0 else n +} + +/** + * Utils for mask operations. + */ +object MaskLike { + val defaultCharCount = 4 + val defaultMaskedUppercase: Int = 'X' + val defaultMaskedLowercase: Int = 'x' + val defaultMaskedDigit: Int = 'n' + val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL + + def extractCharCount(e: Expression): Int = e match { + case Literal(i, IntegerType | NullType) => + if (i == null) defaultCharCount else i.asInstanceOf[Int] + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } + + def extractReplacement(e: Expression): String = e match { + case Literal(s, StringType | NullType) => if (s == null) null else s.toString + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${StringType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } +} + +/** + * Masks the input string. Additional parameters can be set to change the masking chars for + * uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); + llll-UUUU-####-#### + """) +// scalastyle:on line.size.limit +case class Mask(child: Expression, upper: String, lower: String, digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLike { + + def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) + + def this(child: Expression, upper: Expression) = + this(child, extractReplacement(upper), null, null) + + def this(child: Expression, upper: Expression, lower: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), null) + + def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val sb = new java.lang.StringBuilder(length) + appendMaskedToStringBuilder(sb, str, 0, length) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |StringBuilder $sb = new StringBuilder($length); + |${CodeGenerator.JAVA_INT} $offset = 0; + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} + |${ev.value} = UTF8String.fromString($sb.toString()); + """.stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) +} + +/** + * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-5678-8765-4321 + """) +// scalastyle:on line.size.limit +case class MaskFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_first_n" +} + +/** + * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-5678-8765-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? + | 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_last_n" +} + +/** + * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-nnnn-nnnn-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_first_n" +} + +/** + * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-nnnn-nnnn-4321 + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_last_n" +} + +/** + * Returns a hashed value based on str. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321"); + 60c713f5ec6912229d2060df1c322776 + """) +// scalastyle:on line.size.limit +case class MaskHash(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def nullSafeEval(input: Any): Any = { + UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") + s""" + |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_hash" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala new file mode 100644 index 0000000000000..4d69dc32ace82 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.{IntegerType, StringType} + +class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("mask") { + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") + checkEvaluation( + new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-####-####") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), + "llll-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal(null, StringType)), null) + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") + checkEvaluation(new Mask( + Literal("abcd-EFGH-8765-4321"), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), + "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("")), "") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") + checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), + "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") + // scalastyle:on nonascii + intercept[AnalysisException] { + checkEvaluation(new Mask(Literal(""), Literal(1)), "") + } + } + + test("mask_first_n") { + checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UFGH-8765") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UFGH-8765-4321") + checkEvaluation( + new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XFGH-8765-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-EFGH-8765-4321") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("")), "") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-EFGH-8765-4321") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xxxxo World") + // scalastyle:on nonascii + } + + test("mask_last_n") { + checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), + "abcd-EFGU-lU#l") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EFGU-####") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), + "abcd-EFGX-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") + } + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal(null, StringType)), null) + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), + "abcd-EFUU-nnnn-nnnn") + checkEvaluation(new MaskLastN(Literal("")), "") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), + "abcx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), + "あ, 𠀋, Hello Xxxxx あ 𠀋") + // scalastyle:on nonascii + } + + test("mask_show_first_n") { + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), + "abcd-EUUU-####-lU#l") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EUUU-####-####") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "abcd-EXXX-nnnn-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-UUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("")), "") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "abcd-XXXX-nnnn-nnnn") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hellx Xxxxx") + // scalastyle:on nonascii + } + + test("mask_show_last_n") { + checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UUUH-8765") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-###5-4321") + checkEvaluation( + new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XXXX-nnn5-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-UUUU-nnnn-4321") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("")), "") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-XXXX-nnnn-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xello World") + // scalastyle:on nonascii + } + + test("mask_hash") { + checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") + checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") + checkEvaluation(MaskHash(Literal(null, StringType)), null) + // scalastyle:off nonascii + checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") + // scalastyle:on nonascii + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5ab9cb3fb86a5..443ba2aa3757d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3499,6 +3499,125 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Mask functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns a string which is the masked representation of the input. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column): Column = withExpr { new Mask(e.expr) } + + /** + * Returns a string which is the masked representation of the input, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { + Mask(e.expr, upper, lower, digit) + } + + /** + * Returns a string with the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } + + /** + * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } + + /** + * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n(e: Column, n: Int): Column = withExpr { + new MaskShowFirstN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n(e: Column, n: Int): Column = withExpr { + new MaskShowLastN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a hashed value based on the input column. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 79e743d961af8..cc8bad4ded53e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,6 +276,113 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("mask functions") { + val df = Seq("TestString-123", "", null).toDF("a") + checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4)), + Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4)), + Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_hash($"a")), + Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), + Row("d41d8cd98f00b204e9800998ecf8427e"), + Row(null))) + + checkAnswer(df.select(mask($"a", "U", "l", "#")), + Seq(Row("UlllUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), + Seq(Row("TestString-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), + Seq(Row("TestUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllUlllll-123"), Row(""), Row(null))) + + checkAnswer( + df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), + Row("", "", "", ""), + Row(null, null, null, null))) + checkAnswer(sql("select mask(null)"), Row(null)) + checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) + intercept[AnalysisException] { + checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + } + + checkAnswer( + df.selectExpr( + "mask_first_n(a)", + "mask_first_n(a, 6)", + "mask_first_n(a, 6, 'U')", + "mask_first_n(a, 6, 'U', 'l')", + "mask_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", + "UlllUlring-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) + } + + checkAnswer( + df.selectExpr( + "mask_last_n(a)", + "mask_last_n(a, 6)", + "mask_last_n(a, 6, 'U')", + "mask_last_n(a, 6, 'U', 'l')", + "mask_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", + "TestStrill-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_first_n(a)", + "mask_show_first_n(a, 6)", + "mask_show_first_n(a, 6, 'U')", + "mask_show_first_n(a, 6, 'U', 'l')", + "mask_show_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", + "TestStllll-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_last_n(a)", + "mask_show_last_n(a, 6)", + "mask_show_last_n(a, 6, 'U')", + "mask_show_last_n(a, 6, 'U', 'l')", + "mask_show_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", + "UlllUlllng-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) + } + + checkAnswer(sql("select mask_hash(null)"), Row(null)) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From 24ef7fbfa94fc85663eaf49cc574f81387c66c62 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 30 May 2018 15:31:40 -0700 Subject: [PATCH 586/593] [SPARK-24276][SQL] Order of literals in IN should not affect semantic equality ## What changes were proposed in this pull request? When two `In` operators are created with the same list of values, but different order, we are considering them as semantically different. This is wrong, since they have the same semantic meaning. The PR adds a canonicalization rule which orders the literals in the `In` operator so the semantic equality works properly. ## How was this patch tested? added UT Author: Marco Gaido Closes #21331 from mgaido91/SPARK-24276. --- .../catalyst/expressions/Canonicalize.scala | 6 +++ .../expressions/CanonicalizeSuite.scala | 53 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index d848ba18356d3..7541f527a52a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + * - Elements in [[In]] are reordered by `hashCode`. */ object Canonicalize { def execute(e: Expression): Expression = { @@ -85,6 +86,11 @@ object Canonicalize { case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + // order the list in the In operator + // In subqueries contain only one element of type ListQuery. So checking that the length > 1 + // we are not reordering In subqueries. + case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case _ => e } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala new file mode 100644 index 0000000000000..28e6940f3cca3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Range + +class CanonicalizeSuite extends SparkFunSuite { + + test("SPARK-24276: IN expression with different order are semantically equal") { + val range = Range(1, 1, 1, 1) + val idAttr = range.output.head + + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + + assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) + assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) + + assert(range.where(in1).sameResult(range.where(in2))) + assert(!range.where(in1).sameResult(range.where(in3))) + + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(2), Literal(1))))) + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + CreateArray(Seq(Literal(1), Literal(2))))) + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(3), Literal(1))))) + + assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) + assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash()) + + assert(range.where(arrays1).sameResult(range.where(arrays2))) + assert(!range.where(arrays1).sameResult(range.where(arrays3))) + } +} From 0053e153faaa76ea38c845adab137d5be970e5af Mon Sep 17 00:00:00 2001 From: William Sheu Date: Wed, 30 May 2018 22:37:27 -0700 Subject: [PATCH 587/593] [SPARK-24337][CORE] Improve error messages for Spark conf values ## What changes were proposed in this pull request? Improve the exception messages when retrieving Spark conf values to include the key name when the value is invalid. ## How was this patch tested? Unit tests for all get* operations in SparkConf that require a specific value format Author: William Sheu Closes #21454 from PenguinToast/SPARK-24337-spark-config-errors. --- .../scala/org/apache/spark/SparkConf.scala | 85 ++++++++++++++----- .../org/apache/spark/SparkConfSuite.scala | 32 +++++++ 2 files changed, 96 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index dab409572646f..6c4c5c94cfa28 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -265,16 +265,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String): Long = { + def getTimeAsSeconds(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key)) } /** * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String, defaultValue: String): Long = { + def getTimeAsSeconds(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key, defaultValue)) } @@ -282,16 +284,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String): Long = { + def getTimeAsMs(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key)) } /** * Get a time parameter as milliseconds, falling back to a default if not set. If no * suffix is provided then milliseconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String, defaultValue: String): Long = { + def getTimeAsMs(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key, defaultValue)) } @@ -299,23 +303,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String): Long = { + def getSizeAsBytes(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key)) } /** * Get a size parameter as bytes, falling back to a default if not set. If no * suffix is provided then bytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: String): Long = { + def getSizeAsBytes(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue)) } /** * Get a size parameter as bytes, falling back to a default if not set. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: Long): Long = { + def getSizeAsBytes(key: String, defaultValue: Long): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue + "B")) } @@ -323,16 +330,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String): Long = { + def getSizeAsKb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key)) } /** * Get a size parameter as Kibibytes, falling back to a default if not set. If no * suffix is provided then Kibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String, defaultValue: String): Long = { + def getSizeAsKb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key, defaultValue)) } @@ -340,16 +349,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String): Long = { + def getSizeAsMb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key)) } /** * Get a size parameter as Mebibytes, falling back to a default if not set. If no * suffix is provided then Mebibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String, defaultValue: String): Long = { + def getSizeAsMb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key, defaultValue)) } @@ -357,16 +368,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String): Long = { + def getSizeAsGb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key)) } /** * Get a size parameter as Gibibytes, falling back to a default if not set. If no * suffix is provided then Gibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String, defaultValue: String): Long = { + def getSizeAsGb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key, defaultValue)) } @@ -394,23 +407,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } - /** Get a parameter as an integer, falling back to a default if not set */ - def getInt(key: String, defaultValue: Int): Int = { + /** + * Get a parameter as an integer, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as an integer + */ + def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) { getOption(key).map(_.toInt).getOrElse(defaultValue) } - /** Get a parameter as a long, falling back to a default if not set */ - def getLong(key: String, defaultValue: Long): Long = { + /** + * Get a parameter as a long, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as a long + */ + def getLong(key: String, defaultValue: Long): Long = catchIllegalValue(key) { getOption(key).map(_.toLong).getOrElse(defaultValue) } - /** Get a parameter as a double, falling back to a default if not set */ - def getDouble(key: String, defaultValue: Double): Double = { + /** + * Get a parameter as a double, falling back to a default if not ste + * @throws NumberFormatException If the value cannot be interpreted as a double + */ + def getDouble(key: String, defaultValue: Double): Double = catchIllegalValue(key) { getOption(key).map(_.toDouble).getOrElse(defaultValue) } - /** Get a parameter as a boolean, falling back to a default if not set */ - def getBoolean(key: String, defaultValue: Boolean): Boolean = { + /** + * Get a parameter as a boolean, falling back to a default if not set + * @throws IllegalArgumentException If the value cannot be interpreted as a boolean + */ + def getBoolean(key: String, defaultValue: Boolean): Boolean = catchIllegalValue(key) { getOption(key).map(_.toBoolean).getOrElse(defaultValue) } @@ -448,6 +473,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def getenv(name: String): String = System.getenv(name) + /** + * Wrapper method for get() methods which require some specific value format. This catches + * any [[NumberFormatException]] or [[IllegalArgumentException]] and re-raises it with the + * incorrectly configured key in the exception message. + */ + private def catchIllegalValue[T](key: String)(getValue: => T): T = { + try { + getValue + } catch { + case e: NumberFormatException => + // NumberFormatException doesn't have a constructor that takes a cause for some reason. + throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}") + .initCause(e) + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e) + } + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index bff808eb540ac..0d06b02e74e34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + val defaultIllegalValue = "SomeIllegalValue" + val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( + "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), + "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)), + "getTimeAsMs" -> (_.getTimeAsMs(_)), + "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)), + "getSizeAsBytes" -> (_.getSizeAsBytes(_)), + "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)), + "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)), + "getSizeAsKb" -> (_.getSizeAsKb(_)), + "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)), + "getSizeAsMb" -> (_.getSizeAsMb(_)), + "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)), + "getSizeAsGb" -> (_.getSizeAsGb(_)), + "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)), + "getInt" -> (_.getInt(_, 0)), + "getLong" -> (_.getLong(_, 0L)), + "getDouble" -> (_.getDouble(_, 0.0)), + "getBoolean" -> (_.getBoolean(_, false)) + ) + + illegalValueTests.foreach { case (name, getValue) => + test(s"SPARK-24337: $name throws an useful error message with key name") { + val key = "SomeKey" + val conf = new SparkConf() + conf.set(key, "SomeInvalidValue") + val thrown = intercept[IllegalArgumentException] { + getValue(conf, key) + } + assert(thrown.getMessage.contains(key)) + } + } } class Class1 {} From 90ae98d1accb3e4b7d381de072257bdece8dd7e0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 31 May 2018 06:53:10 -0700 Subject: [PATCH 588/593] [SPARK-24146][PYSPARK][ML] spark.ml parity for sequential pattern mining - PrefixSpan: Python API ## What changes were proposed in this pull request? spark.ml parity for sequential pattern mining - PrefixSpan: Python API ## How was this patch tested? doctests Author: WeichenXu Closes #21265 from WeichenXu123/prefix_span_py. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 6 +- python/pyspark/ml/fpm.py | 104 +++++++++++++++++- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 41716c621ca98..bd1c1a8885201 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params @Since("2.4.0") val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + "sequential pattern. Sequential pattern that appears more than " + - "(minSupport * size-of-the-dataset)." + + "(minSupport * size-of-the-dataset) " + "times will be output.", ParamValidators.gtEq(0.0)) /** @group getParam */ @@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. * * @param dataset A dataset or a dataframe containing a sequence column which is - * {{{Seq[Seq[_]]}}} type + * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset. * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: - * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) * - `freq: Long` */ @Since("2.4.0") diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354d..fd19fd96c4df6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,8 +16,9 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * __all__ = ["FPGrowth", "FPGrowthModel"] @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. " + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) From 698b9a0981f0ec322e15d6ac89cc38c8f49ed33d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 31 May 2018 09:34:39 -0700 Subject: [PATCH 589/593] [WEBUI] Avoid possibility of script in query param keys As discussed separately, this avoids the possibility of XSS on certain request param keys. CC vanzin Author: Sean Owen Closes #21464 from srowen/XSS2. --- .../src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 4 +++- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index f651fe97c2cd5..178d2c8d1a10a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -206,7 +206,9 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b8b20db1fa407..56e4d6838a99a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -43,7 +43,9 @@ private[ui] class StageTableBase( killEnabled: Boolean, isFailedStage: Boolean) { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) From 7a82e93b349b4f414f2075dd5add8e4ed72fe357 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 31 May 2018 10:05:20 -0700 Subject: [PATCH 590/593] [SPARK-24414][UI] Calculate the correct number of tasks for a stage. This change takes into account all non-pending tasks when calculating the number of tasks to be shown. This also means that when the stage is pending, the task table (or, in fact, most of the data in the stage page) will not be rendered. I also fixed the label when the known number of tasks is larger than the recorded number of tasks (it was inverted). Author: Marcelo Vanzin Closes #21457 from vanzin/SPARK-24414. --- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2575914121c39..d4e6a7bc3effa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -117,8 +117,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks + val totalTasks = taskCount(stageData) if (totalTasks == 0) { val content =
    @@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${storedTasks}" + s"$storedTasks, showing ${totalTasks}" } val summary = @@ -686,7 +685,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numTasks + override def dataSize: Int = taskCount(stage) override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1052,4 +1051,9 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } + def taskCount(stageData: StageData): Int = { + stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + + stageData.numKilledTasks + } + } From 223df5d9d4fbf48db017edb41f9b7e4033679f35 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 31 May 2018 11:23:57 -0700 Subject: [PATCH 591/593] [SPARK-24397][PYSPARK] Added TaskContext.getLocalProperty(key) in Python ## What changes were proposed in this pull request? This adds a new API `TaskContext.getLocalProperty(key)` to the Python TaskContext. It mirrors the Java TaskContext API of returning a string value if the key exists, or None if the key does not exist. ## How was this patch tested? New test added. Author: Tathagata Das Closes #21437 from tdas/SPARK-24397. --- .../org/apache/spark/api/python/PythonRunner.scala | 7 +++++++ python/pyspark/taskcontext.py | 7 +++++++ python/pyspark/tests.py | 14 ++++++++++++++ python/pyspark/worker.py | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f075a7e0eb0b4..41eac10d9b267 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) + val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e5218d9e75e78..63ae1f30e17ca 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -34,6 +34,7 @@ class TaskContext(object): _partitionId = None _stageId = None _taskAttemptId = None + _localProperties = None def __new__(cls): """Even if users construct TaskContext instead of using get, give them the singleton.""" @@ -88,3 +89,9 @@ def taskAttemptId(self): TaskAttemptID. """ return self._taskAttemptId + + def getLocalProperty(self, key): + """ + Get a local property set upstream in the driver, or None if it is missing. + """ + return self._localProperties.get(key, None) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3b37cc028c1b7..30723b8e15b36 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -574,6 +574,20 @@ def test_tc_on_driver(self): tc = TaskContext.get() self.assertTrue(tc is None) + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + class RDDTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5d2e58bef6466..fbcb8af8bfb24 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -222,6 +222,12 @@ def main(infile, outfile): taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() From cc976f6cb858adb5f52987b56dda54769915ce50 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 31 May 2018 11:38:23 -0700 Subject: [PATCH 592/593] [SPARK-23900][SQL] format_number support user specifed format as argument ## What changes were proposed in this pull request? `format_number` support user specifed format as argument. For example: ```sql spark-sql> SELECT format_number(12332.123456, '##################.###'); 12332.123 ``` ## How was this patch tested? unit test Author: Yuming Wang Closes #21010 from wangyum/SPARK-23900. --- .../expressions/stringExpressions.scala | 142 ++++++++++++------ .../expressions/StringExpressionsSuite.scala | 24 +++ 2 files changed, 116 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9823b2fc5ad97..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1916,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression) usage = """ _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If `expr2` is 0, the result has no decimal point or fractional part. + `expr2` also accept a user specified format. This is supposed to function like MySQL's FORMAT. """, examples = """ Examples: > SELECT _FUNC_(12332.123456, 4); 12,332.1235 + > SELECT _FUNC_(12332.123456, '##################.###'); + 12332.123 """) case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -1930,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(IntegerType, StringType)) + + private val defaultFormat = "#,###,###,###,###,###,##0" // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Option[Int] = None + private var lastDIntValue: Option[Int] = None + + @transient + private var lastDStringValue: Option[String] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @@ -1950,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression) private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - lastDValue match { - case Some(last) if last == dValue => - // use the current pattern - case _ => - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } + right.dataType match { + case IntegerType => + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null } - lastDValue = Some(dValue) + lastDIntValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append(defaultFormat) + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + + lastDIntValue = Some(dValue) - numberFormat.applyLocalizedPattern(pattern.toString) + numberFormat.applyLocalizedPattern(pattern.toString) + } + case StringType => + val dValue = dObject.asInstanceOf[UTF8String].toString + lastDStringValue match { + case Some(last) if last == dValue => + case _ => + pattern.delete(0, pattern.length) + lastDStringValue = Some(dValue) + if (dValue.isEmpty) { + numberFormat.applyLocalizedPattern(defaultFormat) + } else { + numberFormat.applyLocalizedPattern(dValue) + } + } } x.dataType match { @@ -2008,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + right.dataType match { + case IntegerType => + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val i = ctx.freshName("i") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("$defaultFormat"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $lastDValue = $d; + $numberFormat.applyLocalizedPattern($pattern.toString()); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } else { + ${ev.value} = null; + ${ev.isNull} = true; } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); - } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.value} = null; - ${ev.isNull} = true; - } - """ + """ + case StringType => + val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") + val dValue = ctx.freshName("dValue") + s""" + String $dValue = $d.toString(); + if (!$dValue.equals($lastDValue)) { + $lastDValue = $dValue; + if ($dValue.isEmpty()) { + $numberFormat.applyLocalizedPattern("$defaultFormat"); + } else { + $numberFormat.applyLocalizedPattern($dValue); + } + } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + """ + } }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f1a6f9b8889fa..aa334e040d5fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(12831273.23481d), + Literal("###,###,###,###,###.###")), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")), + "123,123,324,123") + checkEvaluation( + FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)), + Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null) + assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")), + "12,332") + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, StringType)), null) + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("find in set") { From 21e1fc7d4aed688d7b685be6ce93f76752159c98 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 31 May 2018 14:28:33 -0700 Subject: [PATCH 593/593] [SPARK-24232][K8S] Add support for secret env vars ## What changes were proposed in this pull request? * Allows to refer a secret as an env var. * Introduces new config properties in the form: spark.kubernetes{driver,executor}.secretKeyRef.ENV_NAME=name:key ENV_NAME is case sensitive. * Updates docs. * Adds required unit tests. ## How was this patch tested? Manually tested and confirmed that the secrets exist in driver's and executor's container env. Also job finished successfully. First created a secret with the following yaml: ``` apiVersion: v1 kind: Secret metadata: name: test-secret data: username: c3RhdnJvcwo= password: Mzk1MjgkdmRnN0pi ------- $ echo -n 'stavros' | base64 c3RhdnJvcw== $ echo -n '39528$vdg7Jb' | base64 MWYyZDFlMmU2N2Rm ``` Run a job as follows: ```./bin/spark-submit \ --master k8s://http://localhost:9000 \ --deploy-mode cluster \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=1 \ --conf spark.kubernetes.container.image=skonto/spark:k8envs3 \ --conf spark.kubernetes.driver.secretKeyRef.MY_USERNAME=test-secret:username \ --conf spark.kubernetes.driver.secretKeyRef.My_password=test-secret:password \ --conf spark.kubernetes.executor.secretKeyRef.MY_USERNAME=test-secret:username \ --conf spark.kubernetes.executor.secretKeyRef.My_password=test-secret:password \ local:///opt/spark/examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar 10000 ``` Secret loaded correctly at the driver container: ![image](https://user-images.githubusercontent.com/7945591/40174346-7fee70c8-59dd-11e8-8705-995a5472716f.png) Also if I log into the exec container: kubectl exec -it spark-pi-1526555613156-exec-1 bash bash-4.4# env > SPARK_EXECUTOR_MEMORY=1g > SPARK_EXECUTOR_CORES=1 > LANG=C.UTF-8 > HOSTNAME=spark-pi-1526555613156-exec-1 > SPARK_APPLICATION_ID=spark-application-1526555618626 > **MY_USERNAME=stavros** > > JAVA_HOME=/usr/lib/jvm/java-1.8-openjdk > KUBERNETES_PORT_443_TCP_PROTO=tcp > KUBERNETES_PORT_443_TCP_ADDR=10.100.0.1 > JAVA_VERSION=8u151 > KUBERNETES_PORT=tcp://10.100.0.1:443 > PWD=/opt/spark/work-dir > HOME=/root > SPARK_LOCAL_DIRS=/var/data/spark-b569b0ae-b7ef-4f91-bcd5-0f55535d3564 > KUBERNETES_SERVICE_PORT_HTTPS=443 > KUBERNETES_PORT_443_TCP_PORT=443 > SPARK_HOME=/opt/spark > SPARK_DRIVER_URL=spark://CoarseGrainedSchedulerspark-pi-1526555613156-driver-svc.default.svc:7078 > KUBERNETES_PORT_443_TCP=tcp://10.100.0.1:443 > SPARK_EXECUTOR_POD_IP=9.0.9.77 > TERM=xterm > SPARK_EXECUTOR_ID=1 > SHLVL=1 > KUBERNETES_SERVICE_PORT=443 > SPARK_CONF_DIR=/opt/spark/conf > PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/lib/jvm/java-1.8-openjdk/jre/bin:/usr/lib/jvm/java-1.8-openjdk/bin > JAVA_ALPINE_VERSION=8.151.12-r0 > KUBERNETES_SERVICE_HOST=10.100.0.1 > **My_password=39528$vdg7Jb** > _=/usr/bin/env > Author: Stavros Kontopoulos Closes #21317 from skonto/k8s-fix-env-secrets. --- docs/running-on-kubernetes.md | 22 +++++++ .../org/apache/spark/deploy/k8s/Config.scala | 2 + .../spark/deploy/k8s/KubernetesConf.scala | 11 +++- .../k8s/features/EnvSecretsFeatureStep.scala | 57 ++++++++++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 11 +++- .../k8s/KubernetesExecutorBuilder.scala | 12 +++- .../deploy/k8s/KubernetesConfSuite.scala | 12 +++- .../BasicDriverFeatureStepSuite.scala | 2 + .../BasicExecutorFeatureStepSuite.scala | 3 + ...ubernetesCredentialsFeatureStepSuite.scala | 3 + .../DriverServiceFeatureStepSuite.scala | 6 ++ .../features/EnvSecretsFeatureStepSuite.scala | 59 +++++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 7 ++- .../features/LocalDirsFeatureStepSuite.scala | 1 + .../MountSecretsFeatureStepSuite.scala | 1 + .../spark/deploy/k8s/submit/ClientSuite.scala | 1 + .../submit/KubernetesDriverBuilderSuite.scala | 13 +++- .../k8s/KubernetesExecutorBuilderSuite.scala | 11 +++- 18 files changed, 222 insertions(+), 12 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e9e1f3e280609..a4b2b98b0b649 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -140,6 +140,12 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` +To use a secret through an environment variable use the following options to the `spark-submit` command: +``` +--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key +--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key +``` + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -602,4 +608,20 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. + + spark.kubernetes.driver.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the driver container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.executor.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4086970ffb256..560dedf431b08 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -162,10 +162,12 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 77b634ddfabcc..5a944187a7096 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -54,6 +54,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleLabels: Map[String, String], roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], + roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -129,6 +130,8 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) @@ -140,6 +143,7 @@ private[spark] object KubernetesConf { driverLabels, driverAnnotations, driverSecretNamesToMountPaths, + driverSecretEnvNamesToKeyRefs, driverEnvs) } @@ -167,8 +171,10 @@ private[spark] object KubernetesConf { executorCustomLabels val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap KubernetesConf( @@ -178,7 +184,8 @@ private[spark] object KubernetesConf { appId, executorLabels, executorAnnotations, - executorSecrets, + executorMountSecrets, + executorEnvSecrets, executorEnv) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala new file mode 100644 index 0000000000000..03ff7d48420ff --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class EnvSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedEnvSecrets = kubernetesConf + .roleSecretEnvNamesToKeyRefs + .map{ case (envName, keyRef) => + // Keyref parts + val keyRefParts = keyRef.split(":") + require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.") + val name = keyRefParts(0) + val key = keyRefParts(1) + new EnvVarBuilder() + .withName(envName) + .withNewValueFrom() + .withNewSecretKeyRef() + .withKey(key) + .withName(name) + .endSecretKeyRef() + .endValueFrom() + .build() + } + + val containerWithEnvVars = new ContainerBuilder(pod.container) + .addAllToEnv(addedEnvSecrets.toSeq.asJava) + .build() + SparkPod(pod.pod, containerWithEnvVars) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 10b0154466a3a..fdc5eb0d75832 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -30,6 +30,9 @@ private[spark] class KubernetesDriverBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -41,10 +44,14 @@ private[spark] class KubernetesDriverBuilder( provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { val configuredPod = feature.configurePod(spec.pod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d8f63d57574fb..d5e1de36a58df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = @@ -25,6 +25,9 @@ private[spark] class KubernetesExecutorBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -32,9 +35,14 @@ private[spark] class KubernetesExecutorBuilder( def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index f10202f7a3546..3d23e1cb90fd2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -40,6 +40,9 @@ class KubernetesConfSuite extends SparkFunSuite { private val SECRET_NAMES_TO_MOUNT_PATHS = Map( "secret1" -> "/mnt/secrets/secret1", "secret2" -> "/mnt/secrets/secret2") + private val SECRET_ENV_VARS = Map( + "envName1" -> "name1:key1", + "envName2" -> "name2:key2") private val CUSTOM_ENVS = Map( "customEnvKey1" -> "customEnvValue1", "customEnvKey2" -> "customEnvValue2") @@ -103,6 +106,9 @@ class KubernetesConfSuite extends SparkFunSuite { SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value) + } CUSTOM_ENVS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) } @@ -121,6 +127,7 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) } @@ -155,6 +162,9 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_ANNOTATIONS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value) + } SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) } @@ -170,6 +180,6 @@ class KubernetesConfSuite extends SparkFunSuite { SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) } - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eee85b8baa730..b2813d8b3265d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -69,6 +69,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, DRIVER_ENVS) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -138,6 +139,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, Map.empty) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a764f7630b5c8..9182134b3337c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -87,6 +87,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) val executor = step.configurePod(SparkPod.initialPod()) @@ -124,6 +125,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -142,6 +144,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map("qux" -> "quux"))) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 9f817d3bfc79a..f81894f8055f1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -59,6 +59,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -88,6 +89,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -124,6 +126,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index c299d56865ec0..f265522a8823a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -65,6 +65,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -94,6 +95,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -113,6 +115,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -141,6 +144,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) val driverService = configurationStep @@ -166,6 +170,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver bind address should not be allowed.") @@ -189,6 +194,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala new file mode 100644 index 0000000000000..8b0b2d0739c76 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class EnvSecretsFeatureStepSuite extends SparkFunSuite{ + private val KEY_REF_NAME_FOO = "foo" + private val KEY_REF_NAME_BAR = "bar" + private val KEY_REF_KEY_FOO = "key_foo" + private val KEY_REF_KEY_BAR = "key_bar" + private val ENV_NAME_FOO = "MY_FOO" + private val ENV_NAME_BAR = "MY_bar" + + test("sets up all keyRefs") { + val baseDriverPod = SparkPod.initialPod() + val envVarsToKeys = Map( + ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", + ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + Map.empty, + envVarsToKeys, + Map.empty) + + val step = new EnvSecretsFeatureStep(kubernetesConf) + val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container + + val expectedVars = + Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") + + expectedVars.foreach { envName => + assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 27bff74ce38af..f90380e30e52a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -58,4 +60,7 @@ object KubernetesFeaturesTestUtils { .build()) } + def containerHasEnvVar(container: Container, envVarName: String): Boolean = { + container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 91e184b84b86e..2542a02d37766 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9d02f56cc206d..9155793774123 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -41,6 +41,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, secretNamesToMountPaths, + Map.empty, Map.empty) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index c1b203e03a357..0775338098a13 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -142,6 +142,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index a511d254d2175..cb724068ea4f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -27,6 +27,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -43,12 +44,16 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, _ => secretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Apply fundamental steps all the time.") { @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("EnvName" -> "SecretName:secretKey"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -93,7 +100,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE + ) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 9ee86b5a423a9..753cd30a237f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,23 +20,27 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Basic steps are consistently applied.") { @@ -49,6 +53,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -64,12 +69,14 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("secret-name" -> "secret-key"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE) } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = {